Example #1
def get_rel_bilinear_scores(entity_emb, entity_scores, num_labels, config, dropout):
  num_sentences = util.shape(entity_emb, 0)
  num_entities = util.shape(entity_emb, 1)

    
  e1_emb_expanded = tf.expand_dims(entity_emb, 2)  # [num_sents, num_ents, 1, emb]
  e2_emb_expanded = tf.expand_dims(entity_emb, 1)  # [num_sents, 1, num_ents, emb]
  e1_emb_tiled = tf.tile(e1_emb_expanded, [1, 1, num_entities, 1])  # [num_sents, num_ents, num_ents, emb]
  e2_emb_tiled = tf.tile(e2_emb_expanded, [1, num_entities, 1, 1])  # [num_sents, num_ents, num_ents, emb]
  
  bilinear_score = util.bilinear(e1_emb_tiled, e2_emb_tiled, num_labels - 1)  # [num_sents, num_ents, num_ents, num_labels - 1]


  pair_emb_list = [e1_emb_tiled, e2_emb_tiled]
  pair_emb = tf.concat(pair_emb_list, 3)  # [num_sentences, num_ents, num_ents, emb]
  pair_emb_size = util.shape(pair_emb, 3)
  # flat_pair_emb = tf.reshape(pair_emb, [num_sentences * num_entities * num_entities, pair_emb_size])
  flat_rel_scores = bilinear_score
  # flat_rel_scores = get_unary_scores(flat_pair_emb, config, dropout, num_labels - 1,
  #     "relation_scores")  # [num_sentences * num_ents * num_ents, 1]

  # flat_rel_scores += bilinear_score
  rel_scores = tf.reshape(flat_rel_scores, [num_sentences, num_entities, num_entities, num_labels - 1])
  rel_scores += tf.expand_dims(tf.expand_dims(entity_scores, 2), 3) + tf.expand_dims(
      tf.expand_dims(entity_scores, 1), 3)  # [num_sentences, num_ents, num_ents, num_labels - 1]
  
  dummy_scores = tf.zeros([num_sentences, num_entities, num_entities, 1], tf.float32)
  rel_scores = tf.concat([dummy_scores, rel_scores], 3)  # [num_sentences, max_num_ents, max_num_ents, num_labels] 
  return rel_scores  # [num_sentences, num_entities, num_entities, num_labels]
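Note: the expand_dims/tile pattern above is how every (e1, e2) pair of entity embeddings gets materialized. A minimal, self-contained sketch of just that pattern (TensorFlow 1.x assumed, toy shapes):

import tensorflow as tf

entity_emb = tf.random_normal([2, 3, 5])  # toy [num_sents=2, num_ents=3, emb=5]
e1 = tf.tile(tf.expand_dims(entity_emb, 2), [1, 1, 3, 1])  # [2, 3, 3, 5]
e2 = tf.tile(tf.expand_dims(entity_emb, 1), [1, 3, 1, 1])  # [2, 3, 3, 5]
pair_emb = tf.concat([e1, e2], 3)  # cell (i, j) holds [emb_i ; emb_j]

with tf.Session() as sess:
    print(sess.run(tf.shape(pair_emb)))  # [ 2  3  3 10]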
Example #2
def get_rel_nonzero_scores(entity_emb, entity_scores, num_labels, config, dropout):
  num_sentences = util.shape(entity_emb, 0)
  num_entities = util.shape(entity_emb, 1)

    
  e1_emb_expanded = tf.expand_dims(entity_emb, 2)  # [num_sents, num_ents, 1, emb]
  e2_emb_expanded = tf.expand_dims(entity_emb, 1)  # [num_sents, 1, num_ents, emb]
  e1_emb_tiled = tf.tile(e1_emb_expanded, [1, 1, num_entities, 1])  # [num_sents, num_ents, num_ents, emb]
  e2_emb_tiled = tf.tile(e2_emb_expanded, [1, num_entities, 1, 1])  # [num_sents, num_ents, num_ents, emb]
  

  similarity_emb = e1_emb_expanded * e2_emb_expanded  # [num_sents, num_ents, num_ents, emb]

  # if config['add_ner_emb']:
  #   pair_emb_list = [ner1_emb_tiled, ner2_emb_tiled, e1_emb_tiled, e2_emb_tiled, similarity_emb]
  # else:
  pair_emb_list = [e1_emb_tiled, e2_emb_tiled, similarity_emb]
  # pair_emb_list = [e1_emb_tiled, e2_emb_tiled, similarity_emb]
  pair_emb = tf.concat(pair_emb_list, 3)  # [num_sentences, num_ents, num_ents, emb]
  pair_emb_size = util.shape(pair_emb, 3)
  flat_pair_emb = tf.reshape(pair_emb, [num_sentences * num_entities * num_entities, pair_emb_size])

  flat_rel_scores = get_unary_scores(flat_pair_emb, config, dropout, num_labels,
      "relation_scores")  # [num_sentences * num_ents * num_ents, num_labels]
  rel_scores = tf.reshape(flat_rel_scores, [num_sentences, num_entities, num_entities, num_labels])
  rel_scores += tf.expand_dims(tf.expand_dims(entity_scores, 2), 3) + tf.expand_dims(
      tf.expand_dims(entity_scores, 1), 3)  # [num_sentences, num_ents, num_ents, num_labels]

  return rel_scores  # [num_sentences, num_entities, num_entities, num_labels]
Example #3
 def _build_graph_cum(
         self, e, k, mf,
         mb):  # cumulative edge graph based on all previous words
     # cumulate
     hidden_size = util.shape(e, 2)  # e: [batch_size, seq_len, hidden_size]
     seq_len = util.shape(e, 1)
     batch_size = util.shape(e, 0)
     cum = tf.tile(tf.expand_dims(e, 2),
                   [1, 1, seq_len, 1
                    ])  # [batch_size, seq_len, seq_len(new), hidden_size]
     # mask_cum = tf.expand_dims(tf.range(seq_len) - tf.transpose(tf.range(seq_len)), 0) # [1, seq_len, seq_len]
     # mask_cum = tf.cast(mask_cum >= 0, tf.float32)
     # tril half of the matrix
      # Outer difference of positions: entry (i, j) = i - j.
      # (tf.transpose of a 1-D tensor is a no-op, so the original
      #  range - transpose(range) produced all zeros instead of a matrix.)
      mask_cum_tril = tf.expand_dims(tf.range(seq_len), 1) - tf.expand_dims(
          tf.range(seq_len), 0)  # [seq_len, seq_len]
     mask_cum_tril_f = tf.cast(mask_cum_tril >= 0, tf.float32)
     mask_cum_tril_b = tf.cast(mask_cum_tril < 0, tf.float32)
     mask_cum_tril_f = tf.tile(tf.expand_dims(mask_cum_tril_f, 2),
                               [1, 1, hidden_size])
     mask_cum_tril_f = tf.tile(tf.expand_dims(mask_cum_tril_f, 0),
                               [batch_size, 1, 1, 1])
     mask_cum_tril_b = tf.tile(tf.expand_dims(mask_cum_tril_b, 2),
                               [1, 1, hidden_size])
     mask_cum_tril_b = tf.tile(tf.expand_dims(mask_cum_tril_b, 0),
                               [batch_size, 1, 1, 1])
     # [batch_size, seq_len, seq_len, hidden_size]
     mask_cum_f = tf.contrib.layers.fully_connected(
         inputs=mf, num_outputs=self.n_outputs, activation_fn=tf.nn.relu)
     # [batch_size, seq_len, hidden_size]
     mask_cum_b = tf.contrib.layers.fully_connected(
         inputs=mb, num_outputs=self.n_outputs, activation_fn=tf.nn.relu)
     # mask_cum = m
     # [batch_size, seq_len, hidden_size]
     # mask_cum = tf.tile(tf.expand_dims(mask_cum, 2), [1, 1, seq_len, 1])
     # mask_cum = tf.matmul(m, tf.transpose(mask_cum, [0, 2, 1])) # [batch_size, seq_len, seq_len]
      mask_cum_f = tf.matmul(mask_cum_f, tf.transpose(
          mask_cum_f,
          [0, 2, 1]))  # [batch_size, seq_len, seq_len]
     mask_cum_f = tf.tile(tf.expand_dims(mask_cum_f, 3),
                          [1, 1, 1, hidden_size])
     mask_cum_b = tf.matmul(mask_cum_b, tf.transpose(mask_cum_b, [0, 2, 1]))
     mask_cum_b = tf.tile(tf.expand_dims(mask_cum_b, 3),
                          [1, 1, 1, hidden_size])
     mask_cum_f = mask_cum_f * mask_cum_tril_f
     mask_cum_b = mask_cum_b * mask_cum_tril_b
     # [batch_size, seq_len, seq_len, hidden_size]
     # mask_cum = mask_cum * mask_cum_tril
     # mask_cum = tf.tile(tf.expand_dims(mask_cum, 2), [1, 1, hidden_size])
     # mask_cum = tf.tile(tf.expand_dims(mask_cum, 0), [batch_size, 1, 1, 1]) # [batch_size, seq_len, seq_len, hidden_size]
     # cum = cum * mask_cum
     cum_f = cum * mask_cum_f
     cum_b = cum * mask_cum_b
     cum = cum_f + cum_b
     cum = tf.reduce_sum(cum, axis=2)
     # cum = cum / tf.reduce_sum(mask_cum, axis = 2)
     # cum = tf.transpose(cum, perm = [0, 2, 1])
     # e * k
     # return tf.matmul(e, tf.transpose(k, [0, 2, 1]))
     return tf.matmul(cum, tf.transpose(k, [0, 2, 1]))
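Note: the forward/backward split above hinges on an outer difference of position indices (entry (i, j) = i - j). A tiny self-contained sketch of that mask construction (TF 1.x):

import tensorflow as tf

seq_len = 4
pos = tf.range(seq_len)
diff = tf.expand_dims(pos, 1) - tf.expand_dims(pos, 0)  # [seq_len, seq_len]
mask_f = tf.cast(diff >= 0, tf.float32)  # on/below diagonal: past and self
mask_b = tf.cast(diff < 0, tf.float32)   # above diagonal: future positions

with tf.Session() as sess:
    print(sess.run(mask_f))  # lower-triangular matrix of ones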
Example #4
    def get_pair_embeddings(self, mention_emb, antecedents, antecedent_emb):
        k = util.shape(mention_emb, 0)
        c = util.shape(antecedents, 1)

        feature_emb_list = []
        antecedent_offsets = tf.tile(tf.expand_dims(tf.range(c) + 1, 0),
                                     [k, 1])  # [k, c]

        if self.config["use_features"]:
            antecedent_distance_buckets = self.bucket_distance(
                antecedent_offsets)  # [k, c]
            antecedent_distance_emb = tf.gather(
                tf.get_variable("antecedent_distance_emb",
                                [10, self.config["feature_size"]]),
                antecedent_distance_buckets)  # [k, c, emb]
            feature_emb_list.append(antecedent_distance_emb)

        feature_emb = tf.concat(feature_emb_list, 2)  # [k, c, emb]
        feature_emb = tf.nn.dropout(feature_emb, self.dropout)  # [k, c, emb]

        target_emb = tf.expand_dims(mention_emb, 1)  # [k, 1, emb]
        similarity_emb = antecedent_emb * target_emb  # [k, c, emb]
        target_emb = tf.tile(target_emb, [1, c, 1])  # [k, c, emb]

        pair_emb = tf.concat(
            [target_emb, antecedent_emb, similarity_emb, feature_emb],
            2)  # [k, c, emb]

        return pair_emb
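Note: self.bucket_distance is not shown on this page. A commonly used implementation in coreference codebases (a sketch, not necessarily this project's exact code) maps distances into 10 semi-logarithmic buckets, which is consistent with the [10, feature_size] embedding table above:

import math
import tensorflow as tf

def bucket_distance(distances):
  """Places distances into 10 buckets: 0, 1, 2, 3, 4, 5-7, 8-15, 16-31, 32-63, 64+."""
  logspace_idx = tf.to_int32(tf.floor(tf.log(tf.to_float(distances)) / math.log(2))) + 3
  use_identity = tf.to_int32(distances <= 4)
  combined_idx = use_identity * distances + (1 - use_identity) * logspace_idx
  return tf.clip_by_value(combined_idx, 0, 9)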
Example #5
def get_span_task_labels(arg_starts, arg_ends, labels, max_sentence_length):
    """Get dense labels for NER/Constituents (unary span prediction tasks).
  """
    num_sentences = util.shape(arg_starts, 0)
    max_num_args = util.shape(arg_starts, 1)
    sentence_indices = tf.tile(
        tf.expand_dims(tf.range(num_sentences),
                       1), [1, max_num_args])  # [num_sentences, max_num_args]
    pred_indices = tf.concat([
        tf.expand_dims(sentence_indices, 2),
        tf.expand_dims(arg_starts, 2),
        tf.expand_dims(arg_ends, 2)
    ],
                             axis=2)  # [num_sentences, max_num_args, 3]

    dense_ner_labels = get_dense_span_labels(
        labels["ner_starts"], labels["ner_ends"], labels["ner_labels"],
        labels["ner_len"],
        max_sentence_length)  # [num_sentences, max_sent_len, max_sent_len]
    dense_coref_labels = get_dense_span_labels(
        labels["coref_starts"], labels["coref_ends"],
        labels["coref_cluster_ids"], labels["coref_len"],
        max_sentence_length)  # [num_sentences, max_sent_len, max_sent_len]

    ner_labels = tf.gather_nd(
        params=dense_ner_labels,
        indices=pred_indices)  # [num_sentences, max_num_args]
    coref_cluster_ids = tf.gather_nd(
        params=dense_coref_labels,
        indices=pred_indices)  # [num_sentences, max_num_args]
    return ner_labels, coref_cluster_ids
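Note: the gather_nd calls above treat pred_indices as a batch of (sentence, start, end) triples indexing into the dense label cube. A toy run (TF 1.x):

import tensorflow as tf

# dense[s, i, j] = label of span (i, j) in sentence s; 0 means "no span".
dense = tf.constant([[[0, 7], [0, 0]]])          # [1 sentence, 2, 2]
indices = tf.constant([[[0, 0, 1], [0, 1, 1]]])  # [1, 2 spans, 3]
labels = tf.gather_nd(dense, indices)            # [1, 2]

with tf.Session() as sess:
    print(sess.run(labels))  # [[7 0]]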
Example #6
def get_rel_softmax_loss(rel_scores, rel_labels, num_predicted_entities):
    """Softmax loss with 2-D masking.
  Args:
    rel_scores: [num_sentences, max_num_entities, max_num_entities, num_labels]
    rel_labels: [num_sentences, max_num_entities, max_num_entities]
    num_predicted_entities: [num_sentences]
  """
    max_num_entities = util.shape(rel_scores, 1)
    num_labels = util.shape(rel_scores, 3)
    entities_mask = tf.sequence_mask(
        num_predicted_entities,
        max_num_entities)  # [num_sentences, max_num_entities]

    rel_loss_mask = tf.logical_and(
        tf.expand_dims(entities_mask,
                       2),  # [num_sentences, max_num_entities, 1]
        tf.expand_dims(entities_mask,
                       1)  # [num_sentences, 1, max_num_entities]
    )  # [num_sentences, max_num_entities, max_num_entities]
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(rel_labels, [-1]),
        logits=tf.reshape(rel_scores, [-1, num_labels]),
        name="srl_softmax_loss"
    )  # [num_sentences * max_num_entities * max_num_entities]
    loss = tf.boolean_mask(loss, tf.reshape(rel_loss_mask, [-1]))
    loss.set_shape([None])
    loss = tf.reduce_sum(loss)
    return loss
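Note: the 2-D masking boils down to an outer AND of a per-sentence length mask with itself, applied to the flattened per-cell losses. A self-contained sketch (TF 1.x, toy sizes):

import tensorflow as tf

num_predicted_entities = tf.constant([2, 1])         # two sentences
mask = tf.sequence_mask(num_predicted_entities, 3)   # [2, 3]
pair_mask = tf.logical_and(tf.expand_dims(mask, 2),
                           tf.expand_dims(mask, 1))  # [2, 3, 3]
per_cell_loss = tf.ones([2, 3, 3])  # stand-in for the cross-entropy values
kept = tf.boolean_mask(tf.reshape(per_cell_loss, [-1]),
                       tf.reshape(pair_mask, [-1]))

with tf.Session() as sess:
    print(sess.run(tf.reduce_sum(kept)))  # 5.0 = 2*2 + 1*1 valid cells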
Example #7
def get_srl_softmax_loss(srl_scores, srl_labels, num_predicted_args,
                         num_predicted_preds):
    """Softmax loss with 2-D masking (for SRL).
  Args:
    srl_scores: [num_sentences, max_num_args, max_num_preds, num_labels]
    srl_labels: [num_sentences, max_num_args, max_num_preds]
    num_predicted_args: [num_sentences]
    num_predicted_preds: [num_sentences]
  """
    max_num_args = util.shape(srl_scores, 1)
    max_num_preds = util.shape(srl_scores, 2)
    num_labels = util.shape(srl_scores, 3)
    args_mask = tf.sequence_mask(num_predicted_args,
                                 max_num_args)  # [num_sentences, max_num_args]
    preds_mask = tf.sequence_mask(
        num_predicted_preds, max_num_preds)  # [num_sentences, max_num_preds]
    srl_loss_mask = tf.logical_and(
        tf.expand_dims(args_mask, 2),  # [num_sentences, max_num_args, 1]
        tf.expand_dims(preds_mask, 1)  # [num_sentences, 1, max_num_preds]
    )  # [num_sentences, max_num_args, max_num_preds]
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(srl_labels, [-1]),
        logits=tf.reshape(srl_scores, [-1, num_labels]),
        name="srl_softmax_loss"
    )  # [num_sentences * max_num_args * max_num_preds]
    loss = tf.boolean_mask(loss, tf.reshape(srl_loss_mask, [-1]))
    loss.set_shape([None])
    loss = tf.reduce_sum(loss)
    return loss
Example #8
def get_rel_softmax_loss(rel_scores, rel_labels, num_predicted_entities,
                         config):
    """Softmax loss with 2-D masking.
  Args:
    rel_scores: [num_sentences, max_num_entities, max_num_entities, num_labels]
    rel_labels: [num_sentences, max_num_entities, max_num_entities]
    num_predicted_entities: [num_sentences]
  """
    max_num_entities = util.shape(rel_scores, 1)
    num_labels = util.shape(rel_scores, 3)
    entities_mask = tf.sequence_mask(
        num_predicted_entities,
        max_num_entities)  # [num_sentences, max_num_entities]
    randp = config['ns_randp']
    print "Negative sample rate: " + str(randp)
    negative_sample_mask = tf.py_func(
        get_negative_sample_mask_func,
        [rel_labels, rel_scores, num_predicted_entities, randp], tf.bool)
    rel_loss_mask = tf.logical_and(
        tf.expand_dims(entities_mask,
                       2),  # [num_sentences, max_num_entities, 1]
        tf.expand_dims(entities_mask,
                       1)  # [num_sentences, 1, max_num_entities]
    )  # [num_sentences, max_num_entities, max_num_entities]
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(rel_labels, [-1]),
        logits=tf.reshape(rel_scores, [-1, num_labels]),
        name="srl_softmax_loss"
    )  # [num_sentences * max_num_entities * max_num_entities]
    # loss = tf.boolean_mask(loss, tf.reshape(rel_loss_mask, [-1]))
    loss = tf.boolean_mask(loss, tf.reshape(negative_sample_mask, [-1]))
    loss.set_shape([None])
    loss = tf.reduce_sum(loss)
    return loss
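Note: get_negative_sample_mask_func is defined elsewhere in the project. Purely as an assumption, a py_func callback with this signature could look like the NumPy sketch below: keep every labeled pair, and keep a random fraction randp of the null-label pairs.

import numpy as np

def get_negative_sample_mask_func(rel_labels, rel_scores,
                                  num_predicted_entities, randp):
  # Hypothetical sketch; the project's real sampling logic is not shown here.
  keep_negatives = np.random.random(rel_labels.shape) < randp
  return np.logical_or(rel_labels > 0, keep_negatives)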
Example #9
def get_dense_span_labels(span_starts, span_ends, span_labels, num_spans, max_sentence_length, span_parents=None):
  """Utility function to get dense span or span-head labels.
  Args:
    span_starts: [num_sentences, max_num_spans]
    span_ends: [num_sentences, max_num_spans]
    span_labels: [num_sentences, max_num_spans]
    num_spans: [num_sentences,]
    max_sentence_length:
    span_parents: [num_sentences, max_num_spans]. Predicates in SRL.
  """
  num_sentences = util.shape(span_starts, 0)
  max_num_spans = util.shape(span_starts, 1)
  # For padded spans, we have starts = 1, and ends = 0, so they don't collide with any existing spans.
  span_starts += (1 - tf.sequence_mask(num_spans, dtype=tf.int32))  # [num_sentences, max_num_spans]
  sentence_indices = tf.tile(
      tf.expand_dims(tf.range(num_sentences), 1),
      [1, max_num_spans])  # [num_sentences, max_num_spans]
  sparse_indices = tf.concat([
      tf.expand_dims(sentence_indices, 2),
      tf.expand_dims(span_starts, 2),
      tf.expand_dims(span_ends, 2)], axis=2)  # [num_sentences, max_num_spans, 3]
  if span_parents is not None:
    sparse_indices = tf.concat([
      sparse_indices, tf.expand_dims(span_parents, 2)], axis=2)  # [num_sentences, max_num_spans, 4]

  rank = 3 if (span_parents is None) else 4
  # (sent_id, span_start, span_end) -> span_label
  dense_labels = tf.sparse_to_dense(
      sparse_indices=tf.reshape(sparse_indices, [num_sentences * max_num_spans, rank]),
      output_shape=[num_sentences] + [max_sentence_length] * (rank - 1),
      sparse_values=tf.reshape(span_labels, [-1]),
      default_value=0,
      validate_indices=False)  # [num_sentences] + [max_sent_len] * (rank - 1)
  return dense_labels
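Note: a toy run of the sparse-to-dense trick above; each (sentence, start, end) triple scatters its label into a cube and every other cell stays at the default 0 (TF 1.x):

import tensorflow as tf

# One sentence of length 4 with spans (0, 1) -> label 3 and (2, 2) -> label 5.
dense = tf.sparse_to_dense(
    sparse_indices=tf.constant([[0, 0, 1], [0, 2, 2]]),
    output_shape=[1, 4, 4],
    sparse_values=tf.constant([3, 5]),
    default_value=0,
    validate_indices=False)

with tf.Session() as sess:
    d = sess.run(dense)
    print(d[0, 0, 1], d[0, 2, 2])  # 3 5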
Example #10
def get_batch_topk(candidate_starts, candidate_ends, candidate_scores, topk_ratio, text_len,
                   max_sentence_length, sort_spans=False, enforce_non_crossing=True):
  """
  Args:
    candidate_starts: [num_sentences, max_num_candidates]
    candidate_mask: [num_sentences, max_num_candidates]
    topk_ratio: A float number.
    text_len: [num_sentences,]
    max_sentence_length:
    enforce_non_crossing: Use regular top-k op if set to False.
 """
  num_sentences = util.shape(candidate_starts, 0)
  max_num_candidates = util.shape(candidate_starts, 1)

  topk = tf.maximum(tf.to_int32(tf.floor(tf.to_float(text_len) * topk_ratio)),
                    tf.ones([num_sentences,], dtype=tf.int32))  # [num_sentences]

  predicted_indices = srl_ops.extract_spans(
      candidate_scores, candidate_starts, candidate_ends, topk, max_sentence_length,
      sort_spans, enforce_non_crossing)  # [num_sentences, max_num_predictions]
  predicted_indices.set_shape([None, None])

  predicted_starts = batch_gather(candidate_starts, predicted_indices)  # [num_sentences, max_num_predictions]
  predicted_ends = batch_gather(candidate_ends, predicted_indices)  # [num_sentences, max_num_predictions]
  predicted_scores = batch_gather(candidate_scores, predicted_indices)  # [num_sentences, max_num_predictions]

  return predicted_starts, predicted_ends, predicted_scores, topk, predicted_indices
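Note: batch_gather is assumed from the project's util module. One common way to implement it (a sketch, not necessarily this project's version) pairs each row's indices with its row id and uses tf.gather_nd:

import tensorflow as tf

def batch_gather(params, indices):
  """params: [B, N, ...], indices: int32 [B, K] -> [B, K, ...]."""
  batch_size = tf.shape(indices)[0]
  k = tf.shape(indices)[1]
  row_ids = tf.tile(tf.expand_dims(tf.range(batch_size), 1), [1, k])  # [B, K]
  return tf.gather_nd(params, tf.stack([row_ids, indices], axis=2))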
Example #11
  def get_slow_antecedent_scores(self, top_span_emb, top_antecedents, top_antecedent_emb, top_antecedent_offsets, top_span_speaker_ids, genre_emb):
    k = util.shape(top_span_emb, 0)
    c = util.shape(top_antecedents, 1)

    feature_emb_list = []

    if self.config["use_metadata"]:
      top_antecedent_speaker_ids = tf.gather(top_span_speaker_ids, top_antecedents) # [k, c]
      same_speaker = tf.equal(tf.expand_dims(top_span_speaker_ids, 1), top_antecedent_speaker_ids) # [k, c]
      speaker_pair_emb = tf.gather(tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker)) # [k, c, emb]
      feature_emb_list.append(speaker_pair_emb)

      tiled_genre_emb = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [k, c, 1]) # [k, c, emb]
      feature_emb_list.append(tiled_genre_emb)

    if self.config["use_features"]:
      antecedent_distance_buckets = self.bucket_distance(top_antecedent_offsets) # [k, c]
      antecedent_distance_emb = tf.gather(tf.get_variable("antecedent_distance_emb", [10, self.config["feature_size"]]), antecedent_distance_buckets) # [k, c, emb]
      feature_emb_list.append(antecedent_distance_emb)

    feature_emb = tf.concat(feature_emb_list, 2) # [k, c, emb]
    feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [k, c, emb]

    target_emb = tf.expand_dims(top_span_emb, 1) # [k, 1, emb]
    similarity_emb = top_antecedent_emb * target_emb # [k, c, emb]
    target_emb = tf.tile(target_emb, [1, c, 1]) # [k, c, emb]

    pair_emb = tf.concat([target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2) # [k, c, emb]

    with tf.variable_scope("slow_antecedent_scores"):
      slow_antecedent_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [k, c, 1]
    slow_antecedent_scores = tf.squeeze(slow_antecedent_scores, 2) # [k, c]
    return slow_antecedent_scores # [k, c]
Example #12
def get_span_emb(head_emb, context_outputs, span_starts, span_ends, config, dropout):
  """Compute span representation shared across tasks.
  Args:
    head_emb: Tensor of [num_words, emb]
    context_outputs: Tensor of [num_words, emb]
    span_starts: [num_spans]
    span_ends: [num_spans]
  """
  text_length = util.shape(context_outputs, 0)
  num_spans = util.shape(span_starts, 0)

  max_arg_width = config["max_arg_width"]
  num_heads = config["num_attention_heads"]

  span_start_emb = tf.gather(context_outputs, span_starts)  # [num_spans, emb]
  span_end_emb = tf.gather(context_outputs, span_ends)  # [num_spans, emb]

  if max_arg_width > 1:
    span_emb_list = [span_start_emb, span_end_emb]
  else:
    span_emb_list = [span_start_emb]

  # span_emb_list = [span_start_emb, span_end_emb]

  span_width = 1 + span_ends - span_starts # [num_spans]
  
  if config["use_features"] and max_arg_width > 1:
    span_width_index = span_width - 1  # [num_spans]
    span_width_emb = tf.gather(
        tf.get_variable("span_width_embeddings", [max_arg_width, config["feature_size"]]),
        span_width_index)  # [num_spans, emb]
    span_width_emb = tf.nn.dropout(span_width_emb, dropout)
    span_emb_list.append(span_width_emb)

  head_scores = None
  span_text_emb = None
  span_indices = None
  span_indices_log_mask = None

  if config["model_heads"]: # and max_arg_width > 1
    if max_arg_width > 1:
      span_indices = tf.minimum(
          tf.expand_dims(tf.range(max_arg_width), 0) + tf.expand_dims(span_starts, 1),
          text_length - 1)  # [num_spans, max_span_width]
      span_text_emb = tf.gather(head_emb, span_indices)  # [num_spans, max_arg_width, emb]
      span_indices_log_mask = tf.log(
          tf.sequence_mask(span_width, max_arg_width, dtype=tf.float32)) # [num_spans, max_arg_width]
      with tf.variable_scope("head_scores"):
        head_scores = util.projection(context_outputs, num_heads)  # [num_words, num_heads]
      span_attention = tf.nn.softmax(
        tf.gather(head_scores, span_indices) + tf.expand_dims(span_indices_log_mask, 2),
        dim=1)  # [num_spans, max_arg_width, num_heads]
      span_head_emb = tf.reduce_sum(span_attention * span_text_emb, 1)  # [num_spans, emb]
    else:
      span_head_emb = tf.gather(head_emb, span_starts)
    span_emb_list.append(span_head_emb)

  span_emb = tf.concat(span_emb_list, 1) # [num_spans, emb]

  return span_emb, head_scores, span_text_emb, span_indices, span_indices_log_mask
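Note: the attention masking above works because adding log(0) = -inf to a logit drives its softmax weight to zero, while log(1) = 0 leaves valid logits untouched. A tiny numeric check (TF 1.x):

import tensorflow as tf

scores = tf.constant([[1.0, 2.0, 3.0]])
mask = tf.constant([[1.0, 1.0, 0.0]])  # last position is padding
attn = tf.nn.softmax(scores + tf.log(mask), axis=1)

with tf.Session() as sess:
    print(sess.run(attn))  # ~[[0.269, 0.731, 0.0]]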
Example #13
def get_antecedent_scores(top_span_emb, top_span_mention_scores, antecedents, config, dropout, top_fast_antecedent_scores, top_antecedent_offsets):
  k = util.shape(top_span_emb, 0)
  max_antecedents = util.shape(antecedents, 1)
  feature_emb_list = []


  if config["use_features"]:
    # target_indices = tf.range(k)  # [k]
    # antecedent_distance = tf.expand_dims(target_indices, 1) - antecedents  # [k, max_ant]
    # antecedent_distance_buckets = bucket_distance(antecedent_distance)  # [k, max_ant]
    antecedent_distance_buckets = bucket_distance(top_antecedent_offsets)
    with tf.variable_scope("features"):
      antecedent_distance_emb = tf.gather(
          tf.get_variable("antecedent_distance_emb", [10, config["feature_size"]]),
          antecedent_distance_buckets)  # [k, max_ant, emb]
    feature_emb_list.append(antecedent_distance_emb)

  feature_emb = tf.concat(feature_emb_list, 2)  # [k, max_ant, emb]
  feature_emb = tf.nn.dropout(feature_emb, dropout)  # [k, max_ant, emb]

  antecedent_emb = tf.gather(top_span_emb, antecedents)  # [k, max_ant, emb]
  target_emb = tf.expand_dims(top_span_emb, 1)  # [k, 1, emb]
  similarity_emb = antecedent_emb * target_emb  # [k, max_ant, emb]
  target_emb = tf.tile(target_emb, [1, max_antecedents, 1])  # [k, max_ant, emb]
  pair_emb = tf.concat([target_emb, antecedent_emb, similarity_emb, feature_emb], 2)  # [k, max_ant, emb]
  with tf.variable_scope("antecedent_scores"):
    antecedent_scores = util.ffnn(pair_emb, config["ffnn_depth"], config["ffnn_size"], 1,
                                  dropout)  # [k, max_ant, 1]
    antecedent_scores = tf.squeeze(antecedent_scores, 2)  # [k, max_ant]
  # antecedent_scores += tf.expand_dims(top_span_mention_scores, 1) + tf.gather(
  #     top_span_mention_scores, antecedents)  # [k, max_ant]
  antecedent_scores += top_fast_antecedent_scores
  return antecedent_scores, antecedent_emb, pair_emb  # [k, max_ant]
Example #14
    def get_feature_attention_score(self, tmp_feature_emb,
                                    tmp_candidate_embedding, tmp_name):
        k = util.shape(tmp_feature_emb, 0)  # tmp_feature_emb: [k, c, feature_size]
        c = util.shape(tmp_feature_emb, 1)
        tmp_feature_size = util.shape(tmp_feature_emb, 2)
        tmp_emb_size = util.shape(tmp_candidate_embedding, 2)
        overall_emb = tf.concat([tmp_candidate_embedding, tmp_feature_emb],
                                2)  # [k, c, feature_size+embedding_size]

        repeated_emb = tf.tile(
            tf.expand_dims(overall_emb, 1),
            [1, c, 1, 1])  # [k, c, c, feature_size+embedding_size]
        tiled_emb = tf.tile(
            tf.expand_dims(overall_emb, 2),
            [1, 1, c, 1])  # [k, c, c, feature_size+embedding_size]

        final_feature = tf.concat(
            [repeated_emb, tiled_emb, repeated_emb * tiled_emb],
            3)  # [k, c, c, (feature_size+embedding_size)*3]
        final_feature = tf.reshape(
            final_feature, [k, c * c, (tmp_feature_size + tmp_emb_size) * 3])
        with tf.variable_scope(tmp_name):
            feature_attention_scores = util.ffnn(final_feature,
                                                 self.config["ffnn_depth"],
                                                 self.config["ffnn_size"], 1,
                                                 self.dropout)  # [k, c*c, 1]
        feature_attention_scores = tf.reshape(feature_attention_scores,
                                              [k, c, c, 1])
        return feature_attention_scores
Example #15
  def get_antecedent_scores(self, mention_emb, mention_scores, antecedents, antecedents_len, mention_starts, mention_ends, mention_speaker_ids, genre_emb, mention_ner_ids):
    num_mentions = util.shape(mention_emb, 0)
    max_antecedents = util.shape(antecedents, 1)

    feature_emb_list = []

    if self.config["use_metadata"]:
      antecedent_speaker_ids = tf.gather(mention_speaker_ids, antecedents) # [num_mentions, max_ant]
      same_speaker = tf.equal(tf.expand_dims(mention_speaker_ids, 1), antecedent_speaker_ids) # [num_mentions, max_ant]
      speaker_pair_emb = tf.gather(tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker)) # [num_mentions, max_ant, emb]
      feature_emb_list.append(speaker_pair_emb)

      # tile is duplicating data [a b c d] --> [a b c d a b c d]
      tiled_genre_emb = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [num_mentions, max_antecedents, 1]) # [num_mentions, max_ant, emb]
      feature_emb_list.append(tiled_genre_emb)

    if self.config["use_features"]:
      target_indices = tf.range(num_mentions) # [num_mentions]
      mention_distance = tf.expand_dims(target_indices, 1) - antecedents # [num_mentions, max_ant]
      mention_distance_bins = coref_ops.distance_bins(mention_distance) # [num_mentions, max_ant]
      mention_distance_bins.set_shape([None, None])
      mention_distance_emb = tf.gather(tf.get_variable("mention_distance_emb", [10, self.config["feature_size"]]), mention_distance_bins) # [num_mentions, max_ant, emb]
      feature_emb_list.append(mention_distance_emb)

    if self.config["use_ner_phi"]:
      antecedent_ner_ids = tf.gather(mention_ner_ids, antecedents)
      same_ner = tf.equal(tf.expand_dims(mention_ner_ids, 1), antecedent_ner_ids)
      ner_pair_emb = tf.gather(tf.get_variable("same_ner_emb", [2, self.config["feature_size"]]), tf.to_int32(same_ner))
      feature_emb_list.append(ner_pair_emb)

    # phi(i, j)
    feature_emb = tf.concat(feature_emb_list, 2) # [num_mentions, max_ant, emb]
    feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [num_mentions, max_ant, emb]

    # g_i
    antecedent_emb = tf.gather(mention_emb, antecedents) # [num_mentions, max_ant, emb]
    
    # g_j 
    target_emb_tiled = tf.tile(tf.expand_dims(mention_emb, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]
    
    # g_i . g_j
    similarity_emb = antecedent_emb * target_emb_tiled # [num_mentions, max_ant, emb]

    # [g_i, g_j, g_i . g_j, phi(i, j)]
    pair_emb = tf.concat([target_emb_tiled, antecedent_emb, similarity_emb, feature_emb], 2) # [num_mentions, max_ant, emb]

    with tf.variable_scope("iteration"):
      with tf.variable_scope("antecedent_scoring"):
        antecedent_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [num_mentions, max_ant, 1]
    antecedent_scores = tf.squeeze(antecedent_scores, 2) # [num_mentions, max_ant]

    antecedent_mask = tf.log(tf.sequence_mask(antecedents_len, max_antecedents, dtype=tf.float32)) # [num_mentions, max_ant]
    antecedent_scores += antecedent_mask # [num_mentions, max_ant]

    antecedent_scores += tf.expand_dims(mention_scores, 1) + tf.gather(mention_scores, antecedents) # [num_mentions, max_ant]
    antecedent_scores = tf.concat([tf.zeros([util.shape(mention_scores, 0), 1]), antecedent_scores], 1) # [num_mentions, max_ant + 1]
    return antecedent_scores, pair_emb # [num_mentions, max_ant + 1]
Example #16
 def get_masked_mention_word_scores(self, encoded_doc, span_starts, span_ends):
     num_words = util.shape(encoded_doc, 0) # T
     num_c = util.shape(span_starts, 0) # NC
     doc_range = tf.tile(tf.expand_dims(tf.range(0, num_words), 0), [num_c, 1]) # [K, T]
     mention_mask = tf.logical_and(doc_range >= tf.expand_dims(span_starts, 1), doc_range <= tf.expand_dims(span_ends, 1)) #[K, T]
     with tf.variable_scope("mention_word_attn"):
       word_attn = tf.squeeze(util.projection(encoded_doc, 1, initializer=tf.truncated_normal_initializer(stddev=0.02)), 1)
     mention_word_attn = tf.nn.softmax(tf.log(tf.to_float(mention_mask)) + tf.expand_dims(word_attn, 0))
     return mention_word_attn
Example #17
def get_srl_scores(arg_emb, pred_emb, arg_scores, pred_scores, num_labels, config, dropout):
  num_sentences = util.shape(arg_emb, 0)
  num_args = util.shape(arg_emb, 1)
  num_preds = util.shape(pred_emb, 1)

  
  arg_emb_expanded = tf.expand_dims(arg_emb, 2)  # [num_sents, num_args, 1, emb]

  pred_emb_expanded = tf.expand_dims(pred_emb, 1)  # [num_sents, 1, num_preds, emb] 

  arg_emb_tiled = tf.tile(arg_emb_expanded, [1, 1, num_preds, 1])  # [num_sentences, num_args, num_preds, emb]

  pred_emb_tiled = tf.tile(pred_emb_expanded, [1, num_args, 1, 1])  # [num_sents, num_args, num_preds, emb]

  pair_emb_list = [arg_emb_tiled, pred_emb_tiled]

  pair_emb = tf.concat(pair_emb_list, 3)  # [num_sentences, num_args, num_preds, emb]

  pair_emb_size = util.shape(pair_emb, 3)

  flat_pair_emb = tf.reshape(pair_emb, [num_sentences * num_args * num_preds, pair_emb_size])

  flat_srl_scores = get_unary_scores(flat_pair_emb, config, dropout, num_labels - 1,
      "predicate_argument_scores")  # [num_sentences * num_args * num_predicates, num_labels - 1]
  
  srl_scores = tf.reshape(flat_srl_scores, [num_sentences, num_args, num_preds, num_labels - 1])

  # if config['use_biaffine_mlp']:
  #   with tf.variable_scope('pred_mlp'):
  #     pred_emb = tf.nn.relu(util.ffnn(pred_emb, 0, -1, config['biaffine_mlp_size'], config['biaffine_mlp_dropout']))
  #   with tf.variable_scope('arg_mlp'):
  #     arg_emb = tf.nn.relu(util.ffnn(arg_emb, 0, -1, config['biaffine_mlp_size'], config['biaffine_mlp_dropout']))
  # srl_scores = biaffine(arg_emb, pred_emb, True, True, num_labels - 1) # [num_sentences, num_args, num_preds, num_labels - 1]
  # srl_scores = biaffine(arg_emb, pred_emb, num_labels - 1, config, dropout) # [num_sentences, num_args, num_preds, num_labels - 1]
  
  if config['use_biaffine']:
    with tf.name_scope("biaffine"):
      bw = tf.Variable(initial_value=[0.5,0.5], name="bw")
      bw_norm = tf.nn.softmax(bw)
      bilinear_srl_scores = bilinear(pred_emb, arg_emb, True, True, num_labels - 1)
      bilinear_srl_scores = tf.transpose(bilinear_srl_scores, [0, 2, 1, 3])
      srl_scores = bw_norm[0] * srl_scores + bw_norm[1] * bilinear_srl_scores


  srl_scores += tf.expand_dims(tf.expand_dims(arg_scores, 2), 3) + tf.expand_dims(
      tf.expand_dims(pred_scores, 1), 3)  # [num_sentences, max_num_args, max_num_preds, num_labels-1]
  
  dummy_scores = tf.zeros([num_sentences, num_args, num_preds, 1], tf.float32)

  srl_scores = tf.concat([dummy_scores, srl_scores], 3)  # [num_sentences, max_num_args, max_num_preds, num_labels] 

  return srl_scores  # [num_sentences, num_args, num_predicates, num_labels]
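Note: two conventions above are easy to miss: the unary argument/predicate scores are broadcast-added into every pairwise cell, and a zero "dummy" column is concatenated so that label id 0 (no relation) always has a fixed score of 0. A compact sketch (TF 1.x, toy sizes):

import tensorflow as tf

pair_scores = tf.zeros([1, 2, 3, 4])             # [sents, args, preds, L-1]
arg_scores = tf.constant([[1.0, 2.0]])           # [sents, args]
pred_scores = tf.constant([[10.0, 20.0, 30.0]])  # [sents, preds]

pair_scores += tf.expand_dims(tf.expand_dims(arg_scores, 2), 3) + \
               tf.expand_dims(tf.expand_dims(pred_scores, 1), 3)
dummy = tf.zeros([1, 2, 3, 1])             # fixed 0 score for the null label
full = tf.concat([dummy, pair_scores], 3)  # [sents, args, preds, L]

with tf.Session() as sess:
    print(sess.run(full)[0, 1, 2])  # [ 0. 32. 32. 32. 32.]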
Example #18
def get_rel_scores(entity_emb, entity_scores, num_labels, config, dropout, num_predicted_entities):
  num_sentences = util.shape(entity_emb, 0)
  num_entities = util.shape(entity_emb, 1)
  entities_mask = tf.sequence_mask(num_predicted_entities, num_entities) #[num_sentences, num_entities]
  flat_entities_mask = tf.reshape(entities_mask, [-1]) 
  rel_mask = tf.logical_and(tf.expand_dims(entities_mask, 2),  # [num_sentences, max_num_entities, 1]
                            tf.expand_dims(entities_mask, 1))  # [num_sentences, 1, max_num_entities]
  e1_emb_expanded = tf.expand_dims(entity_emb, 2)  # [num_sents, num_ents, 1, emb]
  e2_emb_expanded = tf.expand_dims(entity_emb, 1)  # [num_sents, 1, num_ents, emb]
  e1_emb_tiled = tf.tile(e1_emb_expanded, [1, 1, num_entities, 1])  # [num_sents, num_ents, num_ents, emb]
  e2_emb_tiled = tf.tile(e2_emb_expanded, [1, num_entities, 1, 1])  # [num_sents, num_ents, num_ents, emb]
  

  similarity_emb = e1_emb_expanded * e2_emb_expanded  # [num_sents, num_ents, num_ents, emb]

  pair_emb_list = [e1_emb_tiled, e2_emb_tiled, similarity_emb]

  pair_emb = tf.concat(pair_emb_list, 3)  # [num_sentences, num_ents, num_ents, emb]
  pair_emb_size = util.shape(pair_emb, 3)
  flat_pair_emb = tf.reshape(pair_emb, [num_sentences * num_entities * num_entities, pair_emb_size])

  flat_rel_scores = get_unary_scores(flat_pair_emb, config, dropout, num_labels - 1,
      "relation_scores")  # [num_sentences * num_ents * num_ents, num_labels-1]
  rel_scores = tf.reshape(flat_rel_scores, [num_sentences, num_entities, num_entities, num_labels - 1])
  rel_scores += tf.expand_dims(tf.expand_dims(entity_scores, 2), 3) + tf.expand_dims(
      tf.expand_dims(entity_scores, 1), 3)  # [num_sentences, num_ents, num_ents, num_labels - 1]
  if config['rel_prop']:
    flat_rel_scores = tf.reshape(rel_scores, [num_sentences * num_entities* num_entities, num_labels - 1])
    with tf.variable_scope("rel_W"):
      entity_emb_size = util.shape(entity_emb, -1)
      relation_transition = util.projection(tf.nn.relu(flat_rel_scores), entity_emb_size) #f(V)A_R in Eq. 3
      e2_emb_tiled = tf.reshape(e2_emb_tiled, [num_sentences * num_entities * num_entities, entity_emb_size])
      rel_mask = tf.reshape(rel_mask, [-1])
      transformed_embeddings = tf.multiply(tf.transpose(relation_transition * e2_emb_tiled), tf.to_float(rel_mask)) # [entity_emb_size, num_sents * num_ents * num_ents]
      transformed_embeddings = tf.transpose(transformed_embeddings) # [num_sents * num_ents * num_ents, entity_emb_size]
      transformed_embeddings = tf.reshape(transformed_embeddings, [num_sentences, num_entities, num_entities, entity_emb_size]) # [num_sents, num_ents, num_ents, entity_emb_size]
      transformed_embeddings = tf.reduce_sum(transformed_embeddings, 2) # [num_sents, num_ents, entity_emb_size]
      transformed_embeddings = tf.reshape(transformed_embeddings, [num_sentences * num_entities, entity_emb_size])
      entity_emb = tf.reshape(entity_emb, [num_sentences * num_entities, entity_emb_size])
      with tf.variable_scope("f"):
        f = tf.sigmoid(util.projection(tf.concat([transformed_embeddings, entity_emb], 1), entity_emb_size)) # [num_sents * num_ents, entity_emb_size]
        entity_emb = f * transformed_embeddings + (1 - f) * entity_emb # [num_sents * num_ents, entity_emb_size]
      entity_emb = tf.reshape(entity_emb, [num_sentences, num_entities, entity_emb_size])
      
      
  dummy_scores = tf.zeros([num_sentences, num_entities, num_entities, 1], tf.float32)
  rel_scores = tf.concat([dummy_scores, rel_scores], 3)  # [num_sentences, max_num_ents, max_num_ents, num_labels]
  if config['rel_prop']:
    return rel_scores, entity_emb, flat_entities_mask
  else:
    return rel_scores  # [num_sentences, num_entities, num_entities, num_labels]
Example #19
    def get_knowledge_score(self, candidate_NP_embeddings, number_features, plurality_features, candidate_mask):
        k = util.shape(number_features, 0)
        c = util.shape(number_features, 1)

        column_mask = tf.tile(tf.expand_dims(candidate_mask, 1), [1, c, 1])  # [k, c, c]
        row_mask = tf.tile(tf.expand_dims(candidate_mask, 2), [1, 1, c])  # [k, c, c]
        square_mask = column_mask * row_mask  # [k, c, c]

        diagonal_mask = tf.ones([k, c, c]) - tf.tile(tf.expand_dims(tf.diag(tf.ones([c])), 0), [k, 1, 1])
        # we need to find the embedding for these features
        number_emb = tf.gather(tf.get_variable("number_emb", [2, self.config["feature_size"]]),
                               number_features)  # [k, c, feature_size]
        plurality_emb = tf.gather(tf.get_variable("plurality_emb", [2, self.config["feature_size"]]),
                                  plurality_features)  # [k, c, feature_size]

        if self.config['number']:
            number_score = self.get_feature_score(number_emb, 'number_score')  # [k, c, c, 1]
        else:
            number_score = tf.zeros([k, c, c, 1])

        if self.config['plurality']:
            plurality_score = self.get_feature_score(plurality_emb, 'plurality_score')  # [k, c, c, 1]
        else:
            plurality_score = tf.zeros([k, c, c, 1])

        merged_score = tf.concat([number_score, plurality_score], 3)  # [k, c, c, 2]

        if self.config['attention']:
            if self.config['number']:
                number_attention_score = self.get_feature_attention_score(number_emb, candidate_NP_embeddings,
                                                                          'number_attention_score')
            else:
                number_attention_score = tf.ones([k, c, c, 1]) * -1000

            if self.config['plurality']:
                plurality_attention_score = self.get_feature_attention_score(plurality_emb, candidate_NP_embeddings,
                                                                             'plurality_attention_score')
            else:
                plurality_attention_score = tf.ones([k, c, c, 1]) * -1000

            merged_attention_score = tf.concat([number_attention_score, plurality_attention_score], 3)
            all_attention_scores = tf.nn.softmax(merged_attention_score, 3)  # [k, c, c, 2]
            all_scores = merged_score * all_attention_scores
        else:
            all_scores = merged_score
            all_attention_scores = tf.zeros([k, c, c, 2])
        all_scores = tf.reduce_sum(all_scores, 3)  # [k, c, c]
        all_scores = all_scores * diagonal_mask
        all_scores = all_scores * square_mask
        final_score = tf.reduce_mean(all_scores, 2)  # [k, c]

        return final_score, merged_score, all_attention_scores, diagonal_mask, square_mask
Example #20
  def get_slow_antecedent_scores(self, top_span_emb, top_antecedents, top_antecedent_emb, top_antecedent_offsets, top_span_speaker_ids, genre_emb, top_scene_emb, top_antecedent_scene_emb, top_span_genders, top_span_fpronouns):
    k = util.shape(top_span_emb, 0)
    c = util.shape(top_antecedents, 1)

    feature_emb_list = []

    if self.config["use_metadata"]:
      top_antecedent_speaker_ids = tf.gather(top_span_speaker_ids, top_antecedents) # [k, c]
      same_speaker = tf.equal(tf.expand_dims(top_span_speaker_ids, 1), top_antecedent_speaker_ids) # [k, c]
      speaker_pair_emb = tf.gather(tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker)) # [k, c, emb]
      feature_emb_list.append(speaker_pair_emb)

      top_antecedent_genders = tf.gather(top_span_genders, top_antecedents)
      same_gender = ((tf.expand_dims(top_span_genders,1) * top_antecedent_genders) >= 0)
      same_gender_emb = tf.gather(tf.get_variable("same_gender_emb", [2, self.config["feature_size"]]), tf.to_int32(same_gender))
      feature_emb_list.append(same_gender_emb)

      top_antecedent_fpronouns = tf.gather(top_span_fpronouns, top_antecedents) # [k, c]
      fpronoun_count = tf.add(tf.expand_dims(top_span_fpronouns, 1), top_antecedent_fpronouns) # [k, c]
      no_same_speaker = tf.to_int32(tf.logical_not(tf.equal(tf.expand_dims(top_span_speaker_ids, 1), top_antecedent_speaker_ids))) # [k, c]
      same_speaker_and_fp = (tf.add(fpronoun_count,no_same_speaker) < 3)
      same_speaker_and_fp_emb = tf.gather(tf.get_variable("same_speaker_and_fp_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker_and_fp))
      feature_emb_list.append(same_speaker_and_fp_emb)

      #tiled_genre_emb = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [k, c, 1]) # [k, c, emb]
      #feature_emb_list.append(tiled_genre_emb)

    if self.config["use_features"]:
      antecedent_distance_buckets = self.bucket_distance(top_antecedent_offsets) # [k, c]
      antecedent_distance_emb = tf.gather(tf.get_variable("antecedent_distance_emb", [10, self.config["feature_size"]]), antecedent_distance_buckets) # [k, c, emb]
      feature_emb_list.append(antecedent_distance_emb)

    feature_emb = tf.concat(feature_emb_list, 2) # [k, c, emb]
    feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [k, c, emb]

    target_emb = tf.expand_dims(top_span_emb, 1) # [k, 1, emb]
    similarity_emb = top_antecedent_emb * target_emb # [k, c, emb]
    target_emb = tf.tile(target_emb, [1, c, 1]) # [k, c, emb]

    target_scene_emb = tf.expand_dims(top_scene_emb, 1) # [k, 1, emb-scene]
    target_scene_emb = tf.tile(target_scene_emb, [1, c, 1]) # [k, c, emb]

    if self.config['use_video']:
      pair_emb = tf.concat([target_scene_emb, top_antecedent_scene_emb, target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2) # [k, c, emb]
    else:
      pair_emb = tf.concat([target_emb, top_antecedent_emb, similarity_emb, feature_emb], 2) # [k, c, emb]

    with tf.variable_scope("slow_antecedent_scores"):
      slow_antecedent_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [k, c, 1]
    slow_antecedent_scores = tf.squeeze(slow_antecedent_scores, 2) # [k, c]
    return slow_antecedent_scores # [k, c]
Example #21
def get_srl_labels(arg_starts, arg_ends, predicates, labels,
                   max_sentence_length):
    """
  Args:
    arg_starts: [num_sentences, max_num_args]
    arg_ends: [num_sentences, max_num_args]
    predicates: [num_sentences, max_num_predicates]
    labels: Dictionary of label tensors.
    max_sentence_length: An integer scalar.
  """
    num_sentences = util.shape(arg_starts, 0)
    max_num_args = util.shape(arg_starts, 1)
    max_num_preds = util.shape(predicates, 1)
    sentence_indices_2d = tf.tile(
        tf.expand_dims(tf.expand_dims(tf.range(num_sentences), 1), 2),
        [1, max_num_args, max_num_preds
         ])  # [num_sentences, max_num_args, max_num_preds]
    tiled_arg_starts = tf.tile(
        tf.expand_dims(arg_starts, 2),
        [1, 1, max_num_preds])  # [num_sentences, max_num_args, max_num_preds]
    tiled_arg_ends = tf.tile(
        tf.expand_dims(arg_ends, 2),
        [1, 1, max_num_preds])  # [num_sentences, max_num_args, max_num_preds]
    tiled_predicates = tf.tile(
        tf.expand_dims(predicates, 1),
        [1, max_num_args, 1])  # [num_sentences, max_num_args, max_num_preds]
    pred_indices = tf.concat(
        [
            tf.expand_dims(sentence_indices_2d, 3),
            tf.expand_dims(tiled_arg_starts, 3),
            tf.expand_dims(tiled_arg_ends, 3),
            tf.expand_dims(tiled_predicates, 3)
        ],
        axis=3)  # [num_sentences, max_num_args, max_num_preds, 4]

    dense_srl_labels = get_dense_span_labels(
        labels["arg_starts"],
        labels["arg_ends"],
        labels["arg_labels"],
        labels["srl_len"],
        max_sentence_length,
        span_parents=labels["predicates"]
    )  # [num_sentences, max_sent_len, max_sent_len, max_sent_len]
    print("dense_srl_labels: ", dense_srl_labels)
    srl_labels = tf.gather_nd(
        params=dense_srl_labels,
        indices=pred_indices)  # [num_sentences, max_num_args, max_num_preds]
    print("srl_labels: ", srl_labels)
    return srl_labels
Example #22
 def _build_graph_cum_2(self, e, k):
     batch_size = util.shape(e, 0)
     seq_len = util.shape(e, 1)
     hidden_size = util.shape(e, 2)
     mask_cum = tf.contrib.layers.fully_connected(
         inputs=tf.concat([e, k], 2),
         num_outputs=self.n_outputs,
         activation_fn=tf.nn.relu)
     mask_cum = tf.matmul(mask_cum, tf.transpose(mask_cum, [0, 2, 1]))
     mask_cum = tf.tile(tf.expand_dims(mask_cum, 3), [1, 1, 1, hidden_size])
     cum = tf.tile(tf.expand_dims(e, 2), [1, 1, seq_len, 1])
     # [batch_size, seq_len, seq_len, hidden_size]
     cum = cum * mask_cum
     cum = tf.reduce_sum(cum, axis=2)
     return tf.matmul(cum, tf.transpose(k, [0, 2, 1]))
Example #23
def get_ner_candidates(candidate_starts, candidate_ends, candidate_scores, candidate_mask, text_len, topk_ratio):
  """
  Args:
    candidate_starts: [num_sentences, max_num_candidates]
    candidate_mask: [num_sentences, max_num_candidates]
    candidate_scores: [num_sentences, max_num_candidates, num_labels]
 """
  candidate_scores = tf.nn.softmax(candidate_scores, axis = -1)
  num_sentences, max_num_ents, num_labels = candidate_scores.shape
  num_sentences = util.shape(candidate_starts, 0)
  topk = tf.maximum(tf.to_int32(tf.floor(tf.to_float(text_len) * topk_ratio)), tf.ones([num_sentences,], dtype=tf.int32))  # [num_sentences]
  
  candidate_labels = tf.argmax(candidate_scores, -1, output_type=tf.int32) #[num_sentences, max_num_candidates]
  candidate_labels = tf.multiply(candidate_labels, tf.to_int32(candidate_mask)) # [num_sentences, max_num_candidates]
  candidate_maxscores = tf.reduce_max(candidate_scores, reduction_indices=[-1]) # [num_sentences, max_num_candidates]
  # flat_candidate_labels = tf.reshape(candidate_labels, [-1]) # [num_sentences * max_num_candidates]
  # flat_candidate_maxscores = tf.reshape(candidate_maxscores, [-1])
  sorted_indices = tf.contrib.framework.argsort(candidate_maxscores, direction='DESCENDING', axis=-1)
  
  # zero = tf.constant(0, dtype=tf.int32)
  # where = tf.not_equal(candidate_labels, zero)
  # num_entities = tf.reduce_sum(tf.to_int32(where), reduction_indices=[-1])
  # max_entities = tf.reduce_max(num_entities)
  # predicted_indices = tf.where(where, out_type=tf.int32) #[num_entities]
  predicted_indices, num_entities = get_ner_spans(sorted_indices, candidate_labels, topk) #[num_sentences, max_num_entities]
  # predicted_indices = get_ner_spans(candidate_labels, topk) #[num_sentences, max_num_entities]
  predicted_starts = batch_gather(candidate_starts, predicted_indices)
  predicted_ends = batch_gather(candidate_ends, predicted_indices)
  predicted_scores = batch_gather(candidate_maxscores, predicted_indices)
  
  return predicted_starts, predicted_ends, predicted_scores, num_entities, predicted_indices
Example #24
  def lstm_contextualize(self, text_emb, text_len, text_len_mask):
    num_sentences = tf.shape(text_emb)[0]

    current_inputs = text_emb # [num_sentences, max_sentence_length, emb]

    for layer in range(self.config["contextualization_layers"]):
      with tf.variable_scope("layer_{}".format(layer)):
        with tf.variable_scope("fw_cell"):
          cell_fw = util.CustomLSTMCell(self.config["contextualization_size"], num_sentences, self.lstm_dropout)
        with tf.variable_scope("bw_cell"):
          cell_bw = util.CustomLSTMCell(self.config["contextualization_size"], num_sentences, self.lstm_dropout)
        state_fw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_fw.initial_state.c, [num_sentences, 1]), tf.tile(cell_fw.initial_state.h, [num_sentences, 1]))
        state_bw = tf.contrib.rnn.LSTMStateTuple(tf.tile(cell_bw.initial_state.c, [num_sentences, 1]), tf.tile(cell_bw.initial_state.h, [num_sentences, 1]))

        (fw_outputs, bw_outputs), _ = tf.nn.bidirectional_dynamic_rnn(
          cell_fw=cell_fw,
          cell_bw=cell_bw,
          inputs=current_inputs,
          sequence_length=text_len,
          initial_state_fw=state_fw,
          initial_state_bw=state_bw)

        text_outputs = tf.concat([fw_outputs, bw_outputs], 2) # [num_sentences, max_sentence_length, emb]
        text_outputs = tf.nn.dropout(text_outputs, self.lstm_dropout)
        if layer > 0:
          highway_gates = tf.sigmoid(util.projection(text_outputs, util.shape(text_outputs, 2))) # [num_sentences, max_sentence_length, emb]
          text_outputs = highway_gates * text_outputs + (1 - highway_gates) * current_inputs
        current_inputs = text_outputs

    return self.flatten_emb_by_sentence(text_outputs, text_len_mask)
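Note: the layer-to-layer connection above is a highway-style gate: a sigmoid projection of the new outputs decides, per dimension, how much of the new representation versus the previous layer's input to keep. A minimal sketch of just the gate (TF 1.x; tf.layers.dense stands in for util.projection):

import tensorflow as tf

x = tf.random_normal([4, 7, 16])           # previous layer's inputs
h = tf.random_normal([4, 7, 16])           # new LSTM outputs (same width)
gate = tf.sigmoid(tf.layers.dense(h, 16))  # per-dimension gate in (0, 1)
out = gate * h + (1 - gate) * x            # convex mix of new and old

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.shape(out)))  # [ 4  7 16]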
Example #25
  def get_span_emb(self, head_emb, context_outputs, span_starts, span_ends):
    span_emb_list = []

    span_start_emb = tf.gather(context_outputs, span_starts) # [k, emb]
    span_emb_list.append(span_start_emb)

    span_end_emb = tf.gather(context_outputs, span_ends) # [k, emb]
    span_emb_list.append(span_end_emb)

    span_width = 1 + span_ends - span_starts # [k]

    if self.config["use_features"]:
      span_width_index = span_width - 1 # [k]
      span_width_emb = tf.gather(tf.get_variable("span_width_embeddings", [self.config["max_span_width"], self.config["feature_size"]]), span_width_index) # [k, emb]
      span_width_emb = tf.nn.dropout(span_width_emb, self.dropout)
      span_emb_list.append(span_width_emb)

    if self.config["model_heads"]:
      span_indices = tf.expand_dims(tf.range(self.config["max_span_width"]), 0) + tf.expand_dims(span_starts, 1) # [k, max_span_width]
      span_indices = tf.minimum(util.shape(context_outputs, 0) - 1, span_indices) # [k, max_span_width]
      span_text_emb = tf.gather(head_emb, span_indices) # [k, max_span_width, emb]
      with tf.variable_scope("head_scores"):
        self.head_scores = util.projection(context_outputs, 1) # [num_words, 1]
      span_head_scores = tf.gather(self.head_scores, span_indices) # [k, max_span_width, 1]
      span_mask = tf.expand_dims(tf.sequence_mask(span_width, self.config["max_span_width"], dtype=tf.float32), 2) # [k, max_span_width, 1]
      span_head_scores += tf.log(span_mask) # [k, max_span_width, 1]
      span_attention = tf.nn.softmax(span_head_scores, 1) # [k, max_span_width, 1]
      span_head_emb = tf.reduce_sum(span_attention * span_text_emb, 1) # [k, emb]
      span_emb_list.append(span_head_emb)

    span_emb = tf.concat(span_emb_list, 1) # [k, emb]
    return span_emb # [k, emb]
Example #26
def coarse_to_fine_pruning(top_span_emb, top_span_mention_scores, c,
                           mention_doc_ids, dropout):
    k = util.shape(top_span_emb, 0)
    top_span_range = tf.range(k)  # [k]
    antecedent_offsets = tf.expand_dims(top_span_range, 1) - tf.expand_dims(
        top_span_range, 0)  # [k, k]
    antecedents_mask = antecedent_offsets >= 1  # [k, k]
    antecedents = tf.maximum(antecedent_offsets, 0)  # [k, k]
    target_doc_ids = tf.expand_dims(mention_doc_ids, 1)  # [k, 1]
    antecedent_doc_ids = tf.gather(mention_doc_ids, antecedents)  # [k, k]
    antecedents_mask = tf.logical_and(
        tf.equal(target_doc_ids, antecedent_doc_ids),
        antecedents_mask)  # [k,k]
    fast_antecedent_scores = tf.expand_dims(
        top_span_mention_scores, 1) + tf.expand_dims(top_span_mention_scores,
                                                     0)  # [k, k]
    fast_antecedent_scores += tf.log(
        tf.to_float(antecedents_mask
                    ))  # [k, k] can not do masking at the end, need to sort
    fast_antecedent_scores += get_fast_antecedent_scores(
        top_span_emb, dropout)  # [k, k]

    _, top_antecedents = tf.nn.top_k(fast_antecedent_scores, c,
                                     sorted=False)  # [k, c]
    top_antecedents_mask = util.batch_gather(antecedents_mask,
                                             top_antecedents)  # [k, c]
    top_antecedents_mask = tf.squeeze(top_antecedents_mask, -1)
    top_fast_antecedent_scores = util.batch_gather(fast_antecedent_scores,
                                                   top_antecedents)  # [k, c]
    top_fast_antecedent_scores = tf.squeeze(top_fast_antecedent_scores, -1)
    top_antecedent_offsets = util.batch_gather(antecedent_offsets,
                                               top_antecedents)  # [k, c]
    top_antecedent_offsets = tf.squeeze(top_antecedent_offsets, -1)
    return top_antecedents, top_antecedents_mask, top_fast_antecedent_scores, top_antecedent_offsets
Example #27
    def get_feature_score(self, tmp_feature_emb, tmp_feature_name):
        k = util.shape(tmp_feature_emb, 0)
        c = util.shape(tmp_feature_emb, 1)
        repeated_feature_emb = tf.tile(tf.expand_dims(tmp_feature_emb, 1), [1, c, 1, 1])  # [k, c, c, feature_size]
        tiled_feature_emb = tf.tile(tf.expand_dims(tmp_feature_emb, 2), [1, 1, c, 1])  # [k, c, c, feature_size]

        final_feature = tf.concat([repeated_feature_emb, tiled_feature_emb, repeated_feature_emb * tiled_feature_emb],
                                  3)  # [k, c, c, feature_size*3]
        final_feature = tf.reshape(final_feature,
                                   [k, c * c, self.config["feature_size"] * 3])  # [k, c*c, feature_size*3]

        with tf.variable_scope(tmp_feature_name):
            tmp_feature_scores = util.ffnn(final_feature, self.config["ffnn_depth"], self.config["ffnn_size"], 1,
                                           self.dropout)  # [k, c*c, 1]
            tmp_feature_scores = tf.reshape(tmp_feature_scores, [k, c, c, 1])  # [k, c, c, 1]
        return tmp_feature_scores
Example #28
  def get_mention_emb(self, text_emb, text_outputs, mention_starts, mention_ends):
    mention_emb_list = []

    mention_start_emb = tf.gather(text_outputs, mention_starts) # [num_mentions, emb]
    mention_emb_list.append(mention_start_emb)

    mention_end_emb = tf.gather(text_outputs, mention_ends) # [num_mentions, emb]
    mention_emb_list.append(mention_end_emb)

    mention_width = 1 + mention_ends - mention_starts # [num_mentions]
    if self.config["use_features"]:
      mention_width_index = mention_width - 1 # [num_mentions]
      mention_width_emb = tf.gather(tf.get_variable("mention_width_embeddings", [self.config["max_mention_width"], self.config["feature_size"]]), mention_width_index) # [num_mentions, emb]
      mention_width_emb = tf.nn.dropout(mention_width_emb, self.dropout)
      mention_emb_list.append(mention_width_emb)

    if self.config["model_heads"]:
      mention_indices = tf.expand_dims(tf.range(self.config["max_mention_width"]), 0) + tf.expand_dims(mention_starts, 1) # [num_mentions, max_mention_width]
      mention_indices = tf.minimum(util.shape(text_outputs, 0) - 1, mention_indices) # [num_mentions, max_mention_width]
      mention_text_emb = tf.gather(text_emb, mention_indices) # [num_mentions, max_mention_width, emb]
      self.head_scores = util.projection(text_outputs, 1) # [num_words, 1]
      mention_head_scores = tf.gather(self.head_scores, mention_indices) # [num_mentions, max_mention_width, 1]
      mention_mask = tf.expand_dims(tf.sequence_mask(mention_width, self.config["max_mention_width"], dtype=tf.float32), 2) # [num_mentions, max_mention_width, 1]
      mention_attention = tf.nn.softmax(mention_head_scores + tf.log(mention_mask), dim=1) # [num_mentions, max_mention_width, 1]
      mention_head_emb = tf.reduce_sum(mention_attention * mention_text_emb, 1) # [num_mentions, emb]
      mention_emb_list.append(mention_head_emb)

    mention_emb = tf.concat(mention_emb_list, 1) # [num_mentions, emb]
    return mention_emb
Example #30
def get_span_candidates(text_len, max_sentence_length, max_mention_width):
    """Enumerates candidate spans up to max_mention_width in each sentence.

    Args:
      text_len: int32 Tensor of shape [num_sentences].
      max_sentence_length: Integer scalar.
      max_mention_width: Integer scalar.
    """
    num_sentences = util.shape(text_len, 0)
    candidate_starts = tf.tile(
        tf.expand_dims(tf.expand_dims(tf.range(max_sentence_length), 0), 1),
        [num_sentences, max_mention_width, 1])  # [num_sentences, max_mention_width, max_sentence_length]
    candidate_widths = tf.expand_dims(
        tf.expand_dims(tf.range(max_mention_width), 0),
        2)  # [1, max_mention_width, 1]
    candidate_ends = candidate_starts + candidate_widths  # [num_sentences, max_mention_width, max_sentence_length]

    candidate_starts = tf.reshape(
        candidate_starts,
        [num_sentences, max_mention_width * max_sentence_length])
    candidate_ends = tf.reshape(
        candidate_ends,
        [num_sentences, max_mention_width * max_sentence_length])
    candidate_mask = tf.less(
        candidate_ends,
        tf.tile(tf.expand_dims(text_len, 1),
                [1, max_mention_width * max_sentence_length]))  # [num_sentences, max_mention_width * max_sentence_length]

    # Mask to avoid indexing error.
    candidate_starts = tf.multiply(candidate_starts,
                                   tf.to_int32(candidate_mask))
    candidate_ends = tf.multiply(candidate_ends, tf.to_int32(candidate_mask))
    return candidate_starts, candidate_ends, candidate_mask
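For intuition, a pure-Python enumeration equivalent to the tensor ops above, on one sentence with text_len = 3, max_sentence_length = 4, max_mention_width = 2 (toy values, chosen here for illustration):

# Pure-Python sketch of the same enumeration and masking (single sentence).
max_sentence_length, max_mention_width, text_len = 4, 2, 3
starts, ends, mask = [], [], []
for width in range(max_mention_width):
    for start in range(max_sentence_length):
        end = start + width
        valid = end < text_len
        mask.append(valid)
        # Invalid candidates are zeroed out, matching the tf.multiply masking.
        starts.append(start if valid else 0)
        ends.append(end if valid else 0)
# starts == [0, 1, 2, 0, 0, 1, 0, 0]
# ends   == [0, 1, 2, 0, 1, 2, 0, 0]
# mask   == [True, True, True, False, True, True, False, False]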
Example #31
def get_softmax_loss(scores, labels, candidate_mask):
    """Softmax loss with 1-D masking. (on Unary factors)
  Args:
    scores: [num_sentences, max_num_candidates, num_labels]
    labels: [num_sentences, max_num_candidates]
    candidate_mask: [num_sentences, max_num_candidates]
  """
    num_labels = util.shape(scores, 2)
    loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
        labels=tf.reshape(labels, [-1]),
        logits=tf.reshape(scores, [-1, num_labels]),
        name="softmax_loss")  # [num_sentences * max_num_candidates]
    loss = tf.boolean_mask(loss, tf.reshape(candidate_mask, [-1]))
    loss.set_shape([None])
    return loss
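A minimal usage sketch with toy shapes (all values hypothetical): 2 sentences, 3 candidates each, 4 labels, and 3 valid candidates in total, so the returned loss vector has 3 entries:

scores = tf.random_uniform([2, 3, 4])                    # [num_sentences, max_num_candidates, num_labels]
labels = tf.constant([[0, 2, 1], [3, 0, 0]])             # [num_sentences, max_num_candidates]
candidate_mask = tf.constant([[True, True, False],
                              [True, False, False]])
loss = get_softmax_loss(scores, labels, candidate_mask)  # [3] -- one term per valid candidate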
Example #32
  def get_antecedent_scores(self, mention_emb, mention_scores, antecedents, antecedents_len, mention_starts, mention_ends, mention_speaker_ids, genre_emb):
    num_mentions = util.shape(mention_emb, 0)
    max_antecedents = util.shape(antecedents, 1)

    feature_emb_list = []

    if self.config["use_metadata"]:
      antecedent_speaker_ids = tf.gather(mention_speaker_ids, antecedents) # [num_mentions, max_ant]
      same_speaker = tf.equal(tf.expand_dims(mention_speaker_ids, 1), antecedent_speaker_ids) # [num_mentions, max_ant]
      speaker_pair_emb = tf.gather(tf.get_variable("same_speaker_emb", [2, self.config["feature_size"]]), tf.to_int32(same_speaker)) # [num_mentions, max_ant, emb]
      feature_emb_list.append(speaker_pair_emb)

      tiled_genre_emb = tf.tile(tf.expand_dims(tf.expand_dims(genre_emb, 0), 0), [num_mentions, max_antecedents, 1]) # [num_mentions, max_ant, emb]
      feature_emb_list.append(tiled_genre_emb)

    if self.config["use_features"]:
      target_indices = tf.range(num_mentions) # [num_mentions]
      mention_distance = tf.expand_dims(target_indices, 1) - antecedents # [num_mentions, max_ant]
      mention_distance_bins = coref_ops.distance_bins(mention_distance) # [num_mentions, max_ant]
      mention_distance_bins.set_shape([None, None])
      mention_distance_emb = tf.gather(tf.get_variable("mention_distance_emb", [10, self.config["feature_size"]]), mention_distance_bins) # [num_mentions, max_ant, emb]
      feature_emb_list.append(mention_distance_emb)

    feature_emb = tf.concat(feature_emb_list, 2) # [num_mentions, max_ant, emb]
    feature_emb = tf.nn.dropout(feature_emb, self.dropout) # [num_mentions, max_ant, emb]

    antecedent_emb = tf.gather(mention_emb, antecedents) # [num_mentions, max_ant, emb]
    target_emb_tiled = tf.tile(tf.expand_dims(mention_emb, 1), [1, max_antecedents, 1]) # [num_mentions, max_ant, emb]
    similarity_emb = antecedent_emb * target_emb_tiled # [num_mentions, max_ant, emb]

    pair_emb = tf.concat([target_emb_tiled, antecedent_emb, similarity_emb, feature_emb], 2) # [num_mentions, max_ant, emb]

    with tf.variable_scope("iteration"):
      with tf.variable_scope("antecedent_scoring"):
        antecedent_scores = util.ffnn(pair_emb, self.config["ffnn_depth"], self.config["ffnn_size"], 1, self.dropout) # [num_mentions, max_ant, 1]
    antecedent_scores = tf.squeeze(antecedent_scores, 2) # [num_mentions, max_ant]

    antecedent_mask = tf.log(tf.sequence_mask(antecedents_len, max_antecedents, dtype=tf.float32)) # [num_mentions, max_ant]
    antecedent_scores += antecedent_mask # [num_mentions, max_ant]

    antecedent_scores += tf.expand_dims(mention_scores, 1) + tf.gather(mention_scores, antecedents) # [num_mentions, max_ant]
    antecedent_scores = tf.concat([tf.zeros([util.shape(mention_scores, 0), 1]), antecedent_scores], 1) # [num_mentions, max_ant + 1]
    return antecedent_scores  # [num_mentions, max_ant + 1]
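The zero column concatenated at the front acts as a fixed-score "no antecedent" (dummy) option; a NumPy sketch of the effect (toy scores, chosen for illustration):

import numpy as np

ant_scores = np.array([[-12.0, -15.0],   # mention 0: all antecedents unlikely
                       [  3.0,   1.0]])  # mention 1: first antecedent likely
full = np.concatenate([np.zeros((2, 1)), ant_scores], axis=1)  # [num_mentions, max_ant + 1]
probs = np.exp(full) / np.exp(full).sum(axis=1, keepdims=True)
# Row 0 puts nearly all mass on column 0 (the dummy: "no antecedent");
# row 1 prefers column 1 (its first candidate antecedent).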
Example #33
  def flatten_emb_by_sentence(self, emb, text_len_mask):
    num_sentences = tf.shape(emb)[0]
    max_sentence_length = tf.shape(emb)[1]

    emb_rank = len(emb.get_shape())
    if emb_rank == 2:
      flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length])
    elif emb_rank == 3:
      flattened_emb = tf.reshape(emb, [num_sentences * max_sentence_length, util.shape(emb, 2)])
    else:
      raise ValueError("Unsupported rank: {}".format(emb_rank))
    return tf.boolean_mask(flattened_emb, text_len_mask)
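A usage sketch under assumed toy shapes (model is a hypothetical instance of this class): 2 sentences padded to length 3 with true lengths [2, 3], so 5 words survive the mask:

emb = tf.random_uniform([2, 3, 5])                             # [num_sentences, max_sentence_length, emb]
text_len_mask = tf.reshape(tf.sequence_mask([2, 3], 3), [-1])  # [6] flattened boolean mask
flat = model.flatten_emb_by_sentence(emb, text_len_mask)       # [5, 5] -- words from all sentences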
Example #34
  def get_predictions_and_loss(self, word_emb, char_index, text_len, speaker_ids, genre, is_training, gold_starts, gold_ends, cluster_ids):
    self.dropout = 1 - (tf.to_float(is_training) * self.config["dropout_rate"])
    self.lexical_dropout = 1 - (tf.to_float(is_training) * self.config["lexical_dropout_rate"])

    num_sentences = tf.shape(word_emb)[0]
    max_sentence_length = tf.shape(word_emb)[1]

    text_emb_list = [word_emb]

    if self.config["char_embedding_size"] > 0:
      char_emb = tf.gather(tf.get_variable("char_embeddings", [len(self.char_dict), self.config["char_embedding_size"]]), char_index) # [num_sentences, max_sentence_length, max_word_length, emb]
      flattened_char_emb = tf.reshape(char_emb, [num_sentences * max_sentence_length, util.shape(char_emb, 2), util.shape(char_emb, 3)]) # [num_sentences * max_sentence_length, max_word_length, emb]
      flattened_aggregated_char_emb = util.cnn(flattened_char_emb, self.config["filter_widths"], self.config["filter_size"]) # [num_sentences * max_sentence_length, emb]
      aggregated_char_emb = tf.reshape(flattened_aggregated_char_emb, [num_sentences, max_sentence_length, util.shape(flattened_aggregated_char_emb, 1)]) # [num_sentences, max_sentence_length, emb]
      text_emb_list.append(aggregated_char_emb)

    text_emb = tf.concat(text_emb_list, 2)
    text_emb = tf.nn.dropout(text_emb, self.lexical_dropout)

    text_len_mask = tf.sequence_mask(text_len, maxlen=max_sentence_length)
    text_len_mask = tf.reshape(text_len_mask, [num_sentences * max_sentence_length])

    text_outputs = self.encode_sentences(text_emb, text_len, text_len_mask)
    text_outputs = tf.nn.dropout(text_outputs, self.dropout)

    genre_emb = tf.gather(tf.get_variable("genre_embeddings", [len(self.genres), self.config["feature_size"]]), genre) # [emb]

    sentence_indices = tf.tile(tf.expand_dims(tf.range(num_sentences), 1), [1, max_sentence_length]) # [num_sentences, max_sentence_length]
    flattened_sentence_indices = self.flatten_emb_by_sentence(sentence_indices, text_len_mask) # [num_words]
    flattened_text_emb = self.flatten_emb_by_sentence(text_emb, text_len_mask) # [num_words, emb]

    candidate_starts, candidate_ends = coref_ops.spans(
      sentence_indices=flattened_sentence_indices,
      max_width=self.max_mention_width)
    candidate_starts.set_shape([None])
    candidate_ends.set_shape([None])

    candidate_mention_emb = self.get_mention_emb(flattened_text_emb, text_outputs, candidate_starts, candidate_ends) # [num_candidates, emb]
    candidate_mention_scores = self.get_mention_scores(candidate_mention_emb) # [num_candidates, 1]
    candidate_mention_scores = tf.squeeze(candidate_mention_scores, 1) # [num_candidates]

    k = tf.to_int32(tf.floor(tf.to_float(tf.shape(text_outputs)[0]) * self.config["mention_ratio"]))
    predicted_mention_indices = coref_ops.extract_mentions(candidate_mention_scores, candidate_starts, candidate_ends, k) # [k]
    predicted_mention_indices.set_shape([None])

    mention_starts = tf.gather(candidate_starts, predicted_mention_indices) # [num_mentions]
    mention_ends = tf.gather(candidate_ends, predicted_mention_indices) # [num_mentions]
    mention_emb = tf.gather(candidate_mention_emb, predicted_mention_indices) # [num_mentions, emb]
    mention_scores = tf.gather(candidate_mention_scores, predicted_mention_indices) # [num_mentions]

    mention_start_emb = tf.gather(text_outputs, mention_starts) # [num_mentions, emb]
    mention_end_emb = tf.gather(text_outputs, mention_ends) # [num_mentions, emb]
    mention_speaker_ids = tf.gather(speaker_ids, mention_starts) # [num_mentions]

    max_antecedents = self.config["max_antecedents"]
    antecedents, antecedent_labels, antecedents_len = coref_ops.antecedents(mention_starts, mention_ends, gold_starts, gold_ends, cluster_ids, max_antecedents) # ([num_mentions, max_ant], [num_mentions, max_ant + 1], [num_mentions])
    antecedents.set_shape([None, None])
    antecedent_labels.set_shape([None, None])
    antecedents_len.set_shape([None])

    antecedent_scores = self.get_antecedent_scores(mention_emb, mention_scores, antecedents, antecedents_len, mention_starts, mention_ends, mention_speaker_ids, genre_emb) # [num_mentions, max_ant + 1]

    loss = self.softmax_loss(antecedent_scores, antecedent_labels) # [num_mentions]
    loss = tf.reduce_sum(loss) # []

    return [candidate_starts, candidate_ends, candidate_mention_scores, mention_starts, mention_ends, antecedents, antecedent_scores], loss
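For reference, one common formulation of the self.softmax_loss call above is a marginal log-likelihood over all gold antecedents (a sketch assuming antecedent_labels is a boolean [num_mentions, max_ant + 1] tensor; the actual helper may differ):

# Marginalize scores over every gold antecedent of each mention, then
# normalize against all antecedent options (including the dummy).
gold_scores = antecedent_scores + tf.log(tf.to_float(antecedent_labels))  # mask out non-gold entries
marginalized_gold_scores = tf.reduce_logsumexp(gold_scores, [1])          # [num_mentions]
log_norm = tf.reduce_logsumexp(antecedent_scores, [1])                    # [num_mentions]
loss = log_norm - marginalized_gold_scores                                # [num_mentions]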