class EventExtractor(Model):
    """
    Event extraction for DyGIE.
    """
    # TODO(dwadden) add option to make `mention_feedforward` be the NER tagger.
    def __init__(self,
                 vocab: Vocabulary,
                 trigger_feedforward: FeedForward,
                 trigger_candidate_feedforward: FeedForward,
                 mention_feedforward: FeedForward,  # Used if entity beam is off.
                 argument_feedforward: FeedForward,
                 context_attention: BilinearMatrixAttention,
                 trigger_attention: Seq2SeqEncoder,
                 span_prop: SpanProp,
                 cls_projection: FeedForward,
                 feature_size: int,
                 trigger_spans_per_word: float,
                 argument_spans_per_word: float,
                 loss_weights,
                 trigger_attention_context: bool,
                 event_args_use_trigger_labels: bool,
                 event_args_use_ner_labels: bool,
                 event_args_label_emb: int,
                 shared_attention_context: bool,
                 label_embedding_method: str,
                 event_args_label_predictor: str,
                 event_args_gold_candidates: bool = False,  # If True, use gold argument candidates.
                 context_window: int = 0,
                 softmax_correction: bool = False,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 positive_label_weight: float = 1.0,
                 entity_beam: bool = False,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EventExtractor, self).__init__(vocab, regularizer)

        self._n_ner_labels = vocab.get_vocab_size("ner_labels")
        self._n_trigger_labels = vocab.get_vocab_size("trigger_labels")
        self._n_argument_labels = vocab.get_vocab_size("argument_labels")

        # Embeddings for trigger labels and ner labels, to be used by argument scorer.
        # These will be either one-hot encodings or learned embeddings, depending on "kind".
        self._ner_label_emb = make_embedder(kind=label_embedding_method,
                                            num_embeddings=self._n_ner_labels,
                                            embedding_dim=event_args_label_emb)
        self._trigger_label_emb = make_embedder(kind=label_embedding_method,
                                                num_embeddings=self._n_trigger_labels,
                                                embedding_dim=event_args_label_emb)
        self._label_embedding_method = label_embedding_method

        # Weight on trigger labeling and argument labeling.
        self._loss_weights = loss_weights.as_dict()

        # Trigger candidate scorer.
        null_label = vocab.get_token_index("", "trigger_labels")
        assert null_label == 0  # If not, the dummy class won't correspond to the null label.
        self._trigger_scorer = torch.nn.Sequential(
            TimeDistributed(trigger_feedforward),
            TimeDistributed(torch.nn.Linear(trigger_feedforward.get_output_dim(),
                                            self._n_trigger_labels - 1)))

        self._trigger_attention_context = trigger_attention_context
        if self._trigger_attention_context:
            self._trigger_attention = trigger_attention

        # Make pruners. If `entity_beam` is true, use NER and trigger scorers to construct the beam
        # and only keep candidates that the model predicts are actual entities or triggers.
        self._mention_pruner = make_pruner(mention_feedforward,
                                           entity_beam=entity_beam,
                                           gold_beam=event_args_gold_candidates)
        self._trigger_pruner = make_pruner(trigger_candidate_feedforward,
                                           entity_beam=entity_beam,
                                           gold_beam=False)

        # Argument scorer.
        self._event_args_use_trigger_labels = event_args_use_trigger_labels  # If True, use trigger labels.
        self._event_args_use_ner_labels = event_args_use_ner_labels  # If True, use ner labels to predict args.
        # Method for predicting labels at test time.
        assert event_args_label_predictor in ["hard", "softmax", "gold"]
        self._event_args_label_predictor = event_args_label_predictor
        self._event_args_gold_candidates = event_args_gold_candidates

        # If set to True, then construct a context vector from a bilinear attention over the trigger
        # / argument pair embeddings and the text.
        self._context_window = context_window  # If greater than 0, concatenate context as features.
        self._argument_feedforward = argument_feedforward
        self._argument_scorer = torch.nn.Linear(argument_feedforward.get_output_dim(),
                                                self._n_argument_labels)

        # Distance embeddings.
        self._num_distance_buckets = 10  # Just use 10, which is the default.
        self._distance_embedding = Embedding(self._num_distance_buckets, feature_size)

        # Class token projection.
        self._cls_projection = cls_projection
        self._cls_n_triggers = torch.nn.Linear(self._cls_projection.get_output_dim(), 5)
        self._cls_event_types = torch.nn.Linear(self._cls_projection.get_output_dim(),
                                                self._n_trigger_labels - 1)

        self._trigger_spans_per_word = trigger_spans_per_word
        self._argument_spans_per_word = argument_spans_per_word

        # Context attention for event argument scorer.
        self._shared_attention_context = shared_attention_context
        if self._shared_attention_context:
            self._shared_attention_context_module = context_attention

        # Span propagation object.
        # TODO(dwadden) initialize with `from_params` instead if this ends up working.
        self._span_prop = span_prop
        self._span_prop._trig_arg_embedder = self._compute_trig_arg_embeddings
        self._span_prop._argument_scorer = self._compute_argument_scores

        # Softmax correction parameters.
        self._softmax_correction = softmax_correction
        self._softmax_log_temp = torch.nn.Parameter(
            torch.zeros([1, 1, 1, self._n_argument_labels]))
        self._softmax_log_multiplier = torch.nn.Parameter(
            torch.zeros([1, 1, 1, self._n_argument_labels]))

        # TODO(dwadden) Add metrics.
        self._metrics = EventMetrics()
        self._argument_stats = ArgumentStats()

        self._trigger_loss = torch.nn.CrossEntropyLoss(reduction="sum")
        # TODO(dwadden) add loss weights.
        self._argument_loss = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-1)
        initializer(self)

    @overrides
    def forward(self,  # type: ignore
                trigger_mask,
                trigger_embeddings,
                spans,
                span_mask,
                span_embeddings,  # TODO(dwadden) add type.
                cls_embeddings,
                sentence_lengths,
                output_ner,  # Needed if we're using entity beam approach.
                trigger_labels,
                argument_labels,
                ner_labels,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        TODO(dwadden) Write documentation.
        The trigger embeddings are just the contextualized token embeddings, and the trigger mask
        is the text mask. For the arguments, we consider all the spans.
        """
        cls_projected = self._cls_projection(cls_embeddings)
        auxiliary_loss = self._compute_auxiliary_loss(cls_projected, trigger_labels, trigger_mask)

        ner_scores = output_ner["ner_scores"]
        predicted_ner = output_ner["predicted_ner"]

        # Compute trigger scores.
        trigger_scores = self._compute_trigger_scores(
            trigger_embeddings, cls_projected, trigger_mask)
        _, predicted_triggers = trigger_scores.max(-1)

        # Get trigger candidates for event argument labeling.
        num_trigs_to_keep = torch.floor(
            sentence_lengths.float() * self._trigger_spans_per_word).long()
        num_trigs_to_keep = torch.max(num_trigs_to_keep, torch.ones_like(num_trigs_to_keep))
        num_trigs_to_keep = torch.min(num_trigs_to_keep, 15 * torch.ones_like(num_trigs_to_keep))

        (top_trig_embeddings, top_trig_mask,
         top_trig_indices, top_trig_scores, num_trigs_kept) = self._trigger_pruner(
             trigger_embeddings, trigger_mask, num_trigs_to_keep, trigger_scores)
        top_trig_mask = top_trig_mask.unsqueeze(-1)
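        # A rough sketch (not part of the model) of how the beam size behaves, assuming
        # `trigger_spans_per_word = 0.4`: a 20-token sentence keeps floor(20 * 0.4) = 8
        # trigger candidates, a 2-token sentence keeps the minimum of 1, and a 100-token
        # sentence is capped at 15. Equivalent to:
        #     lengths = torch.tensor([20, 2, 100])
        #     torch.floor(lengths.float() * 0.4).long().clamp(min=1, max=15)
        #     # -> tensor([ 8,  1, 15])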
        # Compute the number of argument spans to keep.
        num_arg_spans_to_keep = torch.floor(
            sentence_lengths.float() * self._argument_spans_per_word).long()
        num_arg_spans_to_keep = torch.max(num_arg_spans_to_keep,
                                          torch.ones_like(num_arg_spans_to_keep))
        num_arg_spans_to_keep = torch.min(num_arg_spans_to_keep,
                                          30 * torch.ones_like(num_arg_spans_to_keep))

        # If we're using gold event arguments, include the gold labels.
        gold_labels = ner_labels if self._event_args_gold_candidates else None
        (top_arg_embeddings, top_arg_mask,
         top_arg_indices, top_arg_scores, num_arg_spans_kept) = self._mention_pruner(
             span_embeddings, span_mask, num_arg_spans_to_keep, ner_scores, gold_labels)

        top_arg_mask = top_arg_mask.unsqueeze(-1)
        top_arg_spans = util.batched_index_select(spans, top_arg_indices)

        # Collect trigger and ner labels, in case they're included as features to the argument
        # classifier.
        # At train time, use the gold labels. At test time, use the labels predicted by the model,
        # or gold if specified.
        if self.training or self._event_args_label_predictor == "gold":
            top_trig_labels = trigger_labels.gather(1, top_trig_indices)
            top_ner_labels = ner_labels.gather(1, top_arg_indices)
        else:
            # Hard predictions.
            if self._event_args_label_predictor == "hard":
                top_trig_labels = predicted_triggers.gather(1, top_trig_indices)
                top_ner_labels = predicted_ner.gather(1, top_arg_indices)
            # Softmax predictions.
            else:
                softmax_triggers = trigger_scores.softmax(dim=-1)
                top_trig_labels = util.batched_index_select(softmax_triggers, top_trig_indices)
                softmax_ner = ner_scores.softmax(dim=-1)
                top_ner_labels = util.batched_index_select(softmax_ner, top_arg_indices)

        # Make a dict of all arguments that are needed to make trigger / argument pair embeddings.
        trig_arg_emb_dict = dict(cls_projected=cls_projected,
                                 top_trig_labels=top_trig_labels,
                                 top_ner_labels=top_ner_labels,
                                 top_trig_indices=top_trig_indices,
                                 top_arg_spans=top_arg_spans,
                                 text_emb=trigger_embeddings,
                                 text_mask=trigger_mask)

        # Run span graph propagation, if asked for.
        if self._span_prop._n_span_prop > 0:
            top_trig_embeddings, top_arg_embeddings = self._span_prop(
                top_trig_embeddings, top_arg_embeddings, top_trig_mask, top_arg_mask,
                top_trig_scores, top_arg_scores, trig_arg_emb_dict)

            top_trig_indices_repeat = (top_trig_indices.unsqueeze(-1)
                                       .repeat(1, 1, top_trig_embeddings.size(-1)))
            updated_trig_embeddings = trigger_embeddings.scatter(
                1, top_trig_indices_repeat, top_trig_embeddings)

            # Recompute the trigger scores.
            trigger_scores = self._compute_trigger_scores(
                updated_trig_embeddings, cls_projected, trigger_mask)
            _, predicted_triggers = trigger_scores.max(-1)

        trig_arg_embeddings = self._compute_trig_arg_embeddings(
            top_trig_embeddings, top_arg_embeddings, **trig_arg_emb_dict)
        argument_scores = self._compute_argument_scores(
            trig_arg_embeddings, top_trig_scores, top_arg_scores, top_arg_mask)

        _, predicted_arguments = argument_scores.max(-1)
        predicted_arguments -= 1  # The null argument has label -1.

        output_dict = {"top_trigger_indices": top_trig_indices,
                       "top_argument_spans": top_arg_spans,
                       "trigger_scores": trigger_scores,
                       "argument_scores": argument_scores,
                       "predicted_triggers": predicted_triggers,
                       "predicted_arguments": predicted_arguments,
                       "num_triggers_kept": num_trigs_kept,
                       "num_argument_spans_kept": num_arg_spans_kept,
                       "sentence_lengths": sentence_lengths}
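        # Note on the label convention (a sketch, not model code): `argument_scores` carries a
        # zero-score dummy column at index 0 for the null class, so argmax-minus-one maps
        # "no argument" to -1 and real labels to their vocabulary indices. For instance:
        #     scores = torch.tensor([[0.0, 2.0, -1.0, 0.5]])  # column 0 is the null dummy.
        #     scores.max(-1)[1] - 1  # -> tensor([0]): real label 0, not the null class.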
        # Evaluate loss and F1 if labels were provided.
        if trigger_labels is not None and argument_labels is not None:
            # Compute the loss for both triggers and arguments.
            trigger_loss = self._get_trigger_loss(trigger_scores, trigger_labels, trigger_mask)

            gold_arguments = self._get_pruned_gold_arguments(
                argument_labels, top_trig_indices, top_arg_indices, top_trig_mask, top_arg_mask)
            argument_loss = self._get_argument_loss(argument_scores, gold_arguments)

            # Compute F1.
            predictions = self.decode(output_dict)["decoded_events"]
            assert len(predictions) == len(metadata)  # Make sure length of predictions is right.
            self._metrics(predictions, metadata)
            self._argument_stats(predictions)

            loss = (self._loss_weights["trigger"] * trigger_loss +
                    self._loss_weights["arguments"] * argument_loss +
                    0.05 * auxiliary_loss)
            output_dict["loss"] = loss

        return output_dict

    @overrides
    def decode(self, output_dict):
        """
        Take the output and convert it into a list of dicts. Each entry is a sentence. Each key
        is a pair of span indices for that sentence, and each value is the relation label on that
        span pair.
        """
        outputs = fields_to_batches({k: v.detach().cpu() for k, v in output_dict.items()})

        res = []

        # Collect predictions for each sentence in minibatch.
        for output in outputs:
            decoded_trig = self._decode_trigger(output)
            decoded_args, decoded_args_with_scores = self._decode_arguments(output, decoded_trig)
            entry = dict(trigger_dict=decoded_trig,
                         argument_dict=decoded_args,
                         argument_dict_with_scores=decoded_args_with_scores)
            res.append(entry)

        output_dict["decoded_events"] = res
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        f1_metrics = self._metrics.get_metric(reset)
        argument_stats = self._argument_stats.get_metric(reset)
        res = {}
        res.update(f1_metrics)
        res.update(argument_stats)
        return res

    def _decode_trigger(self, output):
        trigger_dict = {}
        for i in range(output["sentence_lengths"]):
            trig_label = output["predicted_triggers"][i].item()
            if trig_label > 0:
                trigger_dict[i] = self.vocab.get_token_from_index(
                    trig_label, namespace="trigger_labels")
        return trigger_dict

    def _decode_arguments(self, output, decoded_trig):
        argument_dict = {}
        argument_dict_with_scores = {}
        for i, j in itertools.product(range(output["num_triggers_kept"]),
                                      range(output["num_argument_spans_kept"])):
            trig_ix = output["top_trigger_indices"][i].item()
            arg_span = tuple(output["top_argument_spans"][j].tolist())
            arg_label = output["predicted_arguments"][i, j].item()
            # Only include the argument if its putative trigger is predicted as a real trigger.
            if arg_label >= 0 and trig_ix in decoded_trig:
                arg_score = output["argument_scores"][i, j, arg_label + 1].item()
                label_name = self.vocab.get_token_from_index(
                    arg_label, namespace="argument_labels")
                argument_dict[(trig_ix, arg_span)] = label_name
                # Keep around a version with the predicted labels and their scores, for debugging
                # purposes.
                argument_dict_with_scores[(trig_ix, arg_span)] = (label_name, arg_score)
        return argument_dict, argument_dict_with_scores
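    # Shape of the decoded output, as a hedged illustration (labels and indices here are
    # made-up examples): each sentence yields something like
    #     {"trigger_dict": {7: "Attack"},
    #      "argument_dict": {(7, (2, 4)): "Target"},
    #      "argument_dict_with_scores": {(7, (2, 4)): ("Target", 3.1)}}
    # i.e. triggers are keyed by token index, arguments by (trigger index, argument span).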
    def _compute_auxiliary_loss(self, cls_projected, trigger_labels, trigger_mask):
        num_triggers = ((trigger_labels > 0) * trigger_mask.byte()).sum(dim=1)
        # Truncate at 4.
        num_triggers = torch.min(num_triggers, 4 * torch.ones_like(num_triggers))
        predicted_num_triggers = self._cls_n_triggers(cls_projected)
        num_trigger_loss = F.cross_entropy(
            predicted_num_triggers, num_triggers,
            weight=torch.tensor([1, 3, 3, 3, 3],
                                device=trigger_labels.device,
                                dtype=torch.float),
            reduction="sum")

        label_present = [torch.any(trigger_labels == i, dim=1).unsqueeze(1)
                         for i in range(1, self._n_trigger_labels)]
        label_present = torch.cat(label_present, dim=1)
        if cls_projected.device.type != "cpu":
            label_present = label_present.cuda(cls_projected.device)

        predicted_event_type_logits = self._cls_event_types(cls_projected)
        trigger_label_loss = F.binary_cross_entropy_with_logits(
            predicted_event_type_logits, label_present.float(), reduction="sum")

        return num_trigger_loss + trigger_label_loss

    def _compute_trigger_scores(self, trigger_embeddings, cls_projected, trigger_mask):
        """
        Compute trigger scores for all tokens.
        """
        cls_repeat = cls_projected.unsqueeze(dim=1).repeat(1, trigger_embeddings.size(1), 1)
        trigger_embeddings = torch.cat([trigger_embeddings, cls_repeat], dim=-1)
        if self._trigger_attention_context:
            context = self._trigger_attention(trigger_embeddings, trigger_mask)
            trigger_embeddings = torch.cat([trigger_embeddings, context], dim=2)
        trigger_scores = self._trigger_scorer(trigger_embeddings)
        # Give large negative scores to masked-out elements.
        mask = trigger_mask.unsqueeze(-1)
        trigger_scores = util.replace_masked_values(trigger_scores, mask, -1e20)
        dummy_dims = [trigger_scores.size(0), trigger_scores.size(1), 1]
        dummy_scores = trigger_scores.new_zeros(*dummy_dims)
        trigger_scores = torch.cat((dummy_scores, trigger_scores), -1)
        return trigger_scores

    def _compute_trig_arg_embeddings(self,
                                     top_trig_embeddings, top_arg_embeddings, cls_projected,
                                     top_trig_labels, top_ner_labels, top_trig_indices,
                                     top_arg_spans, text_emb, text_mask):
        """
        Create trigger / argument pair embeddings, consisting of:
        - The embeddings of the trigger and argument pair.
        - Optionally, the embeddings of the trigger and argument labels.
        - Optionally, embeddings of the words surrounding the trigger and argument.
        """
        trig_emb_extras = []
        arg_emb_extras = []

        if self._context_window > 0:
            # Include words in a window around trigger and argument.
            # For triggers, the span start and end indices are the same.
            trigger_context = self._get_context(top_trig_indices, top_trig_indices, text_emb)
            argument_context = self._get_context(
                top_arg_spans[:, :, 0], top_arg_spans[:, :, 1], text_emb)
            trig_emb_extras.append(trigger_context)
            arg_emb_extras.append(argument_context)

        # TODO(dwadden) refactor this. Way too many conditionals.
        if self._event_args_use_trigger_labels:
            if self._event_args_label_predictor == "softmax" and not self.training:
                if self._label_embedding_method == "one_hot":
                    # If we're using one-hot encoding, just return the scores for each class.
                    top_trig_embs = top_trig_labels
                else:
                    # Otherwise take the average of the embeddings, weighted by softmax scores.
                    top_trig_embs = torch.matmul(top_trig_labels, self._trigger_label_emb.weight)
                trig_emb_extras.append(top_trig_embs)
            else:
                trig_emb_extras.append(self._trigger_label_emb(top_trig_labels))
        if self._event_args_use_ner_labels:
            if self._event_args_label_predictor == "softmax" and not self.training:
                # Same deal as for trigger labels.
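                # Sketch of the soft-label trick (illustrative shapes only): if
                # `top_ner_labels` holds softmax distributions of shape
                # [batch, k, n_ner_labels] and the embedding weight is
                # [n_ner_labels, emb_dim], the matmul below yields the expected label
                # embedding of shape [batch, k, emb_dim] under the predicted distribution.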
if self._label_embedding_method == "one_hot": top_ner_embs = top_ner_labels else: top_ner_embs = torch.matmul(top_ner_labels, self._ner_label_emb.weight) arg_emb_extras.append(top_ner_embs) else: # Otherwise, just return the embeddings. arg_emb_extras.append(self._ner_label_emb(top_ner_labels)) num_trigs = top_trig_embeddings.size(1) num_args = top_arg_embeddings.size(1) trig_emb_expanded = top_trig_embeddings.unsqueeze(2) trig_emb_tiled = trig_emb_expanded.repeat(1, 1, num_args, 1) arg_emb_expanded = top_arg_embeddings.unsqueeze(1) arg_emb_tiled = arg_emb_expanded.repeat(1, num_trigs, 1, 1) distance_embeddings = self._compute_distance_embeddings( top_trig_indices, top_arg_spans) cls_repeat = (cls_projected.unsqueeze(dim=1).unsqueeze(dim=2).repeat( 1, num_trigs, num_args, 1)) pair_embeddings_list = [ trig_emb_tiled, arg_emb_tiled, distance_embeddings, cls_repeat ] pair_embeddings = torch.cat(pair_embeddings_list, dim=3) if trig_emb_extras: trig_extras_expanded = torch.cat(trig_emb_extras, dim=-1).unsqueeze(2) trig_extras_tiled = trig_extras_expanded.repeat(1, 1, num_args, 1) pair_embeddings = torch.cat([pair_embeddings, trig_extras_tiled], dim=3) if arg_emb_extras: arg_extras_expanded = torch.cat(arg_emb_extras, dim=-1).unsqueeze(1) arg_extras_tiled = arg_extras_expanded.repeat(1, num_trigs, 1, 1) pair_embeddings = torch.cat([pair_embeddings, arg_extras_tiled], dim=3) if self._shared_attention_context: attended_context = self._get_shared_attention_context( pair_embeddings, text_emb, text_mask) pair_embeddings = torch.cat([pair_embeddings, attended_context], dim=3) return pair_embeddings def _compute_distance_embeddings(self, top_trig_indices, top_arg_spans): top_trig_ixs = top_trig_indices.unsqueeze(2) arg_span_starts = top_arg_spans[:, :, 0].unsqueeze(1) arg_span_ends = top_arg_spans[:, :, 1].unsqueeze(1) dist_from_start = top_trig_ixs - arg_span_starts dist_from_end = top_trig_ixs - arg_span_ends # Distance from trigger to arg. dist = torch.min(dist_from_start.abs(), dist_from_end.abs()) # When the trigger is inside the arg span, also set the distance to zero. trigger_inside = (top_trig_ixs >= arg_span_starts) & (top_trig_ixs <= arg_span_ends) dist[trigger_inside] = 0 dist_buckets = util.bucket_values(dist, self._num_distance_buckets) dist_emb = self._distance_embedding(dist_buckets) trigger_before_feature = (top_trig_ixs < arg_span_starts).float().unsqueeze(-1) trigger_inside_feature = trigger_inside.float().unsqueeze(-1) res = torch.cat( [dist_emb, trigger_before_feature, trigger_inside_feature], dim=-1) return res def _get_shared_attention_context(self, pair_embeddings, text_emb, text_mask): batch_size, n_triggers, n_args, emb_dim = pair_embeddings.size() pair_emb_flat = pair_embeddings.view([batch_size, -1, emb_dim]) attn_unnorm = self._shared_attention_context_module( pair_emb_flat, text_emb) attn_weights = util.masked_softmax(attn_unnorm, text_mask) context = util.weighted_sum(text_emb, attn_weights) context = context.view(batch_size, n_triggers, n_args, -1) return context def _get_context(self, span_starts, span_ends, text_emb): """ Given span start and end (inclusive), get the context on either side. """ # The text_emb are already zero-padded on the right, which is correct. 
        assert span_starts.size() == span_ends.size()
        batch_size, seq_length, emb_size = text_emb.size()
        num_candidates = span_starts.size(1)
        padding = torch.zeros(batch_size, self._context_window, emb_size, device=text_emb.device)
        # [batch_size, seq_length + 2 x context_window, emb_size]
        padded_emb = torch.cat([padding, text_emb, padding], dim=1)
        pad_batch = []
        for batch_ix, (start_ixs, end_ixs) in enumerate(zip(span_starts, span_ends)):
            pad_entry = []
            for start_ix, end_ix in zip(start_ixs, end_ixs):
                # The starts are inclusive, ends are exclusive.
                left_start = start_ix
                left_end = start_ix + self._context_window
                right_start = end_ix + self._context_window + 1
                right_end = end_ix + 2 * self._context_window + 1
                left_pad = padded_emb[batch_ix, left_start:left_end]
                right_pad = padded_emb[batch_ix, right_start:right_end]
                pad = torch.cat([left_pad, right_pad], dim=0).view(-1).unsqueeze(0)
                pad_entry.append(pad)
            pad_entry = torch.cat(pad_entry, dim=0).unsqueeze(0)
            pad_batch.append(pad_entry)

        pad_batch = torch.cat(pad_batch, dim=0)
        return pad_batch

    def _compute_argument_scores(self, pairwise_embeddings, top_trig_scores, top_arg_scores,
                                 top_arg_mask, prepend_zeros=True):
        batch_size = pairwise_embeddings.size(0)
        max_num_trigs = pairwise_embeddings.size(1)
        max_num_args = pairwise_embeddings.size(2)
        feature_dim = self._argument_feedforward.input_dim

        embeddings_flat = pairwise_embeddings.view(-1, feature_dim)

        arguments_projected_flat = self._argument_feedforward(embeddings_flat)
        argument_scores_flat = self._argument_scorer(arguments_projected_flat)

        argument_scores = argument_scores_flat.view(batch_size, max_num_trigs, max_num_args, -1)

        # Add the mention scores for each of the candidates.
        argument_scores += (top_trig_scores.unsqueeze(-1) +
                            top_arg_scores.transpose(1, 2).unsqueeze(-1))

        # Softmax correction to compare arguments.
        if self._softmax_correction:
            the_temp = torch.exp(self._softmax_log_temp)
            the_multiplier = torch.exp(self._softmax_log_multiplier)
            softmax_scores = util.masked_softmax(
                argument_scores / the_temp, mask=top_arg_mask, dim=2)
            argument_scores = argument_scores + the_multiplier * softmax_scores

        shape = [argument_scores.size(0), argument_scores.size(1), argument_scores.size(2), 1]
        dummy_scores = argument_scores.new_zeros(*shape)

        if prepend_zeros:
            argument_scores = torch.cat([dummy_scores, argument_scores], -1)
        return argument_scores
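    # The `prepend_zeros` convention above mirrors the trigger scorer: class 0 is a fixed
    # zero-score null class, so a candidate pair gets a real argument label only when some
    # label's score exceeds 0. A tiny sketch (illustrative values, not model code):
    #     scores = torch.tensor([[-0.3, 0.2]])  # scores for two real labels.
    #     torch.cat([scores.new_zeros(1, 1), scores], -1).argmax(-1)  # -> tensor([2])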
""" arguments = [] zipped = zip(argument_labels, top_trig_indices, top_arg_indices, top_trig_masks.byte(), top_arg_masks.byte()) for sliced, trig_ixs, arg_ixs, trig_mask, arg_mask in zipped: entry = sliced[trig_ixs][:, arg_ixs].unsqueeze(0) mask_entry = trig_mask & arg_mask.transpose(0, 1).unsqueeze(0) entry[mask_entry] += 1 entry[~mask_entry] = -1 arguments.append(entry) return torch.cat(arguments, dim=0) def _get_trigger_loss(self, trigger_scores, trigger_labels, trigger_mask): trigger_scores_flat = trigger_scores.view(-1, self._n_trigger_labels) trigger_labels_flat = trigger_labels.view(-1) mask_flat = trigger_mask.view(-1).byte() loss = self._trigger_loss(trigger_scores_flat[mask_flat], trigger_labels_flat[mask_flat]) return loss def _get_argument_loss(self, argument_scores, argument_labels): """ Compute cross-entropy loss on argument labels. """ # Need to add one for the null class. scores_flat = argument_scores.view(-1, self._n_argument_labels + 1) # Need to add 1 so that the null label is 0, to line up with indices into prediction matrix. labels_flat = argument_labels.view(-1) # Compute cross-entropy loss. loss = self._argument_loss(scores_flat, labels_flat) return loss
    def __init__(self,
                 vocab: Vocabulary,
                 make_feedforward: Callable,
                 text_emb_dim: int,
                 # Triggers are represented via span embeddings (but can have a different width
                 # than arg spans).
                 trigger_emb_dim: int,
                 span_emb_dim: int,  # Arguments are represented via span embeddings.
                 feature_size: int,
                 trigger_spans_per_word: float,
                 argument_spans_per_word: float,
                 loss_weights: Dict[str, float],
                 context_window: int = 0,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EventExtractor, self).__init__(vocab, regularizer)

        self._trigger_namespaces = [entry for entry in vocab.get_namespaces()
                                    if "trigger_labels" in entry]
        self._argument_namespaces = [entry for entry in vocab.get_namespaces()
                                     if "argument_labels" in entry]

        self._n_trigger_labels = {name: vocab.get_vocab_size(name)
                                  for name in self._trigger_namespaces}
        self._n_argument_labels = {name: vocab.get_vocab_size(name)
                                   for name in self._argument_namespaces}

        # Context window. If greater than 0, concatenate context as features.
        self._context_window = context_window
        # 2 (trigger context + argument context) * 2 (left context + right context)
        # * context_window tokens, each of width text_emb_dim.
        context_window_dim = 4 * self._context_window * text_emb_dim

        # Make sure the null trigger label is always 0.
        for namespace in self._trigger_namespaces:
            null_label = vocab.get_token_index("", namespace)
            assert null_label == 0  # If not, the dummy class won't correspond to the null label.

        # Create trigger scorers and pruners.
        self._trigger_scorers = torch.nn.ModuleDict()
        self._trigger_pruners = torch.nn.ModuleDict()

        for trigger_namespace in self._trigger_namespaces:
            # The trigger pruner.
            trigger_candidate_feedforward = make_feedforward(input_dim=trigger_emb_dim)
            self._trigger_pruners[trigger_namespace] = make_pruner(trigger_candidate_feedforward)

            # The trigger scorer. Keyed by `trigger_namespace`, the loop variable; indexing with
            # the leftover `namespace` from the loop above would overwrite a single entry.
            trigger_scorer_feedforward = make_feedforward(input_dim=trigger_emb_dim)
            self._trigger_scorers[trigger_namespace] = torch.nn.Sequential(
                TimeDistributed(trigger_scorer_feedforward),
                TimeDistributed(torch.nn.Linear(
                    trigger_scorer_feedforward.get_output_dim(),
                    self._n_trigger_labels[trigger_namespace] - 1)))

        # Create argument scorers and pruners.
        self._mention_pruners = torch.nn.ModuleDict()
        self._argument_feedforwards = torch.nn.ModuleDict()
        self._argument_scorers = torch.nn.ModuleDict()

        for argument_namespace in self._argument_namespaces:
            # The argument pruner.
            mention_feedforward = make_feedforward(input_dim=span_emb_dim)
            self._mention_pruners[argument_namespace] = make_pruner(mention_feedforward)

            # The argument scorer. `feature_size + 2` accounts for the bucketed distance
            # embedding plus two indicator features for whether the trigger is before or inside
            # the arg span.
            argument_feedforward_dim = (trigger_emb_dim + span_emb_dim + feature_size + 2 +
                                        context_window_dim)
            argument_feedforward = make_feedforward(input_dim=argument_feedforward_dim)
            self._argument_feedforwards[argument_namespace] = argument_feedforward
            self._argument_scorers[argument_namespace] = torch.nn.Linear(
                argument_feedforward.get_output_dim(),
                self._n_argument_labels[argument_namespace])

        # Weight on trigger labeling and argument labeling.
        self._loss_weights = loss_weights

        # Distance embeddings.
        self._num_distance_buckets = 10  # Just use 10, which is the default.
        self._distance_embedding = Embedding(embedding_dim=feature_size,
                                             num_embeddings=self._num_distance_buckets)

        self._trigger_spans_per_word = trigger_spans_per_word
        self._argument_spans_per_word = argument_spans_per_word

        # Metrics. Make a metric for each dataset (not each namespace).
        namespaces = self._trigger_namespaces + self._argument_namespaces
        datasets = set([x.split("__")[0] for x in namespaces])
        self._metrics = {dataset: EventMetrics() for dataset in datasets}

        self._active_namespaces = {"trigger": None, "argument": None}
        self._active_dataset = None

        # Trigger and argument loss.
        self._trigger_loss = torch.nn.CrossEntropyLoss(reduction="sum")
        self._argument_loss = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-1)
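    # Dimension check for the context features, as a worked example (hypothetical sizes): with
    # `context_window = 2` and `text_emb_dim = 768`, each candidate pair contributes 2 windows
    # (trigger + argument) * 2 sides (left + right) * 2 tokens * 768 dims, i.e.
    # `context_window_dim = 4 * 2 * 768 = 6144` extra features for the argument scorer.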
class EventExtractor(Model):
    """
    Event extraction for DyGIE.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 make_feedforward: Callable,
                 token_emb_dim: int,  # Triggers are represented via token embeddings.
                 span_emb_dim: int,  # Arguments are represented via span embeddings.
                 feature_size: int,
                 trigger_spans_per_word: float,
                 argument_spans_per_word: float,
                 loss_weights: Dict[str, float],
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EventExtractor, self).__init__(vocab, regularizer)

        self._trigger_namespaces = [entry for entry in vocab.get_namespaces()
                                    if "trigger_labels" in entry]
        self._argument_namespaces = [entry for entry in vocab.get_namespaces()
                                     if "argument_labels" in entry]

        self._n_trigger_labels = {name: vocab.get_vocab_size(name)
                                  for name in self._trigger_namespaces}
        self._n_argument_labels = {name: vocab.get_vocab_size(name)
                                   for name in self._argument_namespaces}

        # Make sure the null trigger label is always 0.
        for namespace in self._trigger_namespaces:
            null_label = vocab.get_token_index("", namespace)
            assert null_label == 0  # If not, the dummy class won't correspond to the null label.

        # Create trigger scorers and pruners.
        self._trigger_scorers = torch.nn.ModuleDict()
        self._trigger_pruners = torch.nn.ModuleDict()

        for trigger_namespace in self._trigger_namespaces:
            # The trigger pruner.
            trigger_candidate_feedforward = make_feedforward(input_dim=token_emb_dim)
            self._trigger_pruners[trigger_namespace] = make_pruner(trigger_candidate_feedforward)

            # The trigger scorer. Keyed by `trigger_namespace`, the loop variable; indexing with
            # the leftover `namespace` from the loop above would overwrite a single entry.
            trigger_feedforward = make_feedforward(input_dim=token_emb_dim)
            self._trigger_scorers[trigger_namespace] = torch.nn.Sequential(
                TimeDistributed(trigger_feedforward),
                TimeDistributed(torch.nn.Linear(trigger_feedforward.get_output_dim(),
                                                self._n_trigger_labels[trigger_namespace] - 1)))

        # Create argument scorers and pruners.
        self._mention_pruners = torch.nn.ModuleDict()
        self._argument_feedforwards = torch.nn.ModuleDict()
        self._argument_scorers = torch.nn.ModuleDict()

        for argument_namespace in self._argument_namespaces:
            # The argument pruner.
            mention_feedforward = make_feedforward(input_dim=span_emb_dim)
            self._mention_pruners[argument_namespace] = make_pruner(mention_feedforward)

            # The argument scorer. The `+ 2` is there because I include indicator features for
            # whether the trigger is before or inside the arg span.
            argument_feedforward_dim = token_emb_dim + span_emb_dim + feature_size + 2
            argument_feedforward = make_feedforward(input_dim=argument_feedforward_dim)
            self._argument_feedforwards[argument_namespace] = argument_feedforward
            self._argument_scorers[argument_namespace] = torch.nn.Linear(
                argument_feedforward.get_output_dim(),
                self._n_argument_labels[argument_namespace])

        # Weight on trigger labeling and argument labeling.
        self._loss_weights = loss_weights

        # Distance embeddings.
        self._num_distance_buckets = 10  # Just use 10, which is the default.
        self._distance_embedding = Embedding(embedding_dim=feature_size,
                                             num_embeddings=self._num_distance_buckets)

        self._trigger_spans_per_word = trigger_spans_per_word
        self._argument_spans_per_word = argument_spans_per_word

        # Metrics.
        # TODO(dwadden) Need different metrics for different namespaces.
        self._metrics = EventMetrics()

        self._active_namespaces = {"trigger": None, "argument": None}

        # Trigger and argument loss.
        self._trigger_loss = torch.nn.CrossEntropyLoss(reduction="sum")
        self._argument_loss = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-1)

    ####################

    @overrides
    def forward(self,  # type: ignore
                trigger_mask,
                trigger_embeddings,
                spans,
                span_mask,
                span_embeddings,  # TODO(dwadden) add type.
                sentence_lengths,
                trigger_labels,
                argument_labels,
                ner_labels,
                metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
        """
        The trigger embeddings are just the contextualized token embeddings, and the trigger mask
        is the text mask. For the arguments, we consider all the spans.
        """
        self._active_namespaces = {"trigger": f"{metadata.dataset}__trigger_labels",
                                   "argument": f"{metadata.dataset}__argument_labels"}

        # Compute trigger scores.
        trigger_scores = self._compute_trigger_scores(trigger_embeddings, trigger_mask)

        # Get trigger candidates for event argument labeling.
        num_trigs_to_keep = torch.floor(
            sentence_lengths.float() * self._trigger_spans_per_word).long()
        num_trigs_to_keep = torch.max(num_trigs_to_keep, torch.ones_like(num_trigs_to_keep))
        num_trigs_to_keep = torch.min(num_trigs_to_keep, 15 * torch.ones_like(num_trigs_to_keep))

        trigger_pruner = self._trigger_pruners[self._active_namespaces["trigger"]]
        (top_trig_embeddings, top_trig_mask,
         top_trig_indices, top_trig_scores, num_trigs_kept) = trigger_pruner(
             trigger_embeddings, trigger_mask, num_trigs_to_keep, trigger_scores)
        top_trig_mask = top_trig_mask.unsqueeze(-1)

        # Compute the number of argument spans to keep.
        num_arg_spans_to_keep = torch.floor(
            sentence_lengths.float() * self._argument_spans_per_word).long()
        num_arg_spans_to_keep = torch.max(num_arg_spans_to_keep,
                                          torch.ones_like(num_arg_spans_to_keep))
        num_arg_spans_to_keep = torch.min(num_arg_spans_to_keep,
                                          30 * torch.ones_like(num_arg_spans_to_keep))

        # If we're using gold event arguments, include the gold labels.
        mention_pruner = self._mention_pruners[self._active_namespaces["argument"]]
        gold_labels = None
        (top_arg_embeddings, top_arg_mask,
         top_arg_indices, top_arg_scores, num_arg_spans_kept) = mention_pruner(
             span_embeddings, span_mask, num_arg_spans_to_keep, gold_labels)

        top_arg_mask = top_arg_mask.unsqueeze(-1)
        top_arg_spans = util.batched_index_select(spans, top_arg_indices)

        # Compute trigger / argument pair embeddings.
        trig_arg_embeddings = self._compute_trig_arg_embeddings(
            top_trig_embeddings, top_arg_embeddings, top_trig_indices, top_arg_spans)
        argument_scores = self._compute_argument_scores(
            trig_arg_embeddings, top_trig_scores, top_arg_scores, top_arg_mask)

        # Assemble inputs to do prediction.
        output_dict = {"top_trigger_indices": top_trig_indices,
                       "top_argument_spans": top_arg_spans,
                       "trigger_scores": trigger_scores,
                       "argument_scores": argument_scores,
                       "num_triggers_kept": num_trigs_kept,
                       "num_argument_spans_kept": num_arg_spans_kept,
                       "sentence_lengths": sentence_lengths}

        prediction_dicts, predictions = self.predict(output_dict, metadata)
        output_dict = {"predictions": predictions}

        # Evaluate loss and F1 if labels were provided.
        if trigger_labels is not None and argument_labels is not None:
            # Compute the loss for both triggers and arguments.
            trigger_loss = self._get_trigger_loss(trigger_scores, trigger_labels, trigger_mask)

            gold_arguments = self._get_pruned_gold_arguments(
                argument_labels, top_trig_indices, top_arg_indices, top_trig_mask, top_arg_mask)
            argument_loss = self._get_argument_loss(argument_scores, gold_arguments)

            # Compute F1.
            assert len(prediction_dicts) == len(metadata)  # Make sure length of predictions is right.
            self._metrics(prediction_dicts, metadata)

            loss = (self._loss_weights["trigger"] * trigger_loss +
                    self._loss_weights["arguments"] * argument_loss)

            output_dict["loss"] = loss

        return output_dict

    ####################

    # Embeddings

    def _compute_trig_arg_embeddings(self, top_trig_embeddings, top_arg_embeddings,
                                     top_trig_indices, top_arg_spans):
        """
        Create trigger / argument pair embeddings, consisting of:
        - The embeddings of the trigger and argument pair.
        - Optionally, the embeddings of the trigger and argument labels.
        - Optionally, embeddings of the words surrounding the trigger and argument.
        """
        num_trigs = top_trig_embeddings.size(1)
        num_args = top_arg_embeddings.size(1)

        trig_emb_expanded = top_trig_embeddings.unsqueeze(2)
        trig_emb_tiled = trig_emb_expanded.repeat(1, 1, num_args, 1)

        arg_emb_expanded = top_arg_embeddings.unsqueeze(1)
        arg_emb_tiled = arg_emb_expanded.repeat(1, num_trigs, 1, 1)

        distance_embeddings = self._compute_distance_embeddings(top_trig_indices, top_arg_spans)

        pair_embeddings_list = [trig_emb_tiled, arg_emb_tiled, distance_embeddings]
        pair_embeddings = torch.cat(pair_embeddings_list, dim=3)

        return pair_embeddings

    def _compute_distance_embeddings(self, top_trig_indices, top_arg_spans):
        top_trig_ixs = top_trig_indices.unsqueeze(2)
        arg_span_starts = top_arg_spans[:, :, 0].unsqueeze(1)
        arg_span_ends = top_arg_spans[:, :, 1].unsqueeze(1)
        dist_from_start = top_trig_ixs - arg_span_starts
        dist_from_end = top_trig_ixs - arg_span_ends
        # Distance from trigger to arg.
        dist = torch.min(dist_from_start.abs(), dist_from_end.abs())
        # When the trigger is inside the arg span, also set the distance to zero.
        trigger_inside = (top_trig_ixs >= arg_span_starts) & (top_trig_ixs <= arg_span_ends)
        dist[trigger_inside] = 0
        dist_buckets = util.bucket_values(dist, self._num_distance_buckets)
        dist_emb = self._distance_embedding(dist_buckets)
        trigger_before_feature = (top_trig_ixs < arg_span_starts).float().unsqueeze(-1)
        trigger_inside_feature = trigger_inside.float().unsqueeze(-1)
        res = torch.cat([dist_emb, trigger_before_feature, trigger_inside_feature], dim=-1)
        return res

    ####################

    # Scorers

    def _compute_trigger_scores(self, trigger_embeddings, trigger_mask):
        """
        Compute trigger scores for all tokens.
        """
        trigger_scorer = self._trigger_scorers[self._active_namespaces["trigger"]]
        trigger_scores = trigger_scorer(trigger_embeddings)
        # Give large negative scores to masked-out elements.
        mask = trigger_mask.unsqueeze(-1)
        trigger_scores = util.replace_masked_values(trigger_scores, mask.bool(), -1e20)
        dummy_dims = [trigger_scores.size(0), trigger_scores.size(1), 1]
        dummy_scores = trigger_scores.new_zeros(*dummy_dims)
        trigger_scores = torch.cat((dummy_scores, trigger_scores), -1)
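        # The zero column prepended above acts as the null-trigger class; real labels shift up
        # by one, matching the vocabulary convention that index 0 is the null label.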
        return trigger_scores

    def _compute_argument_scores(self, pairwise_embeddings, top_trig_scores, top_arg_scores,
                                 top_arg_mask, prepend_zeros=True):
        batch_size = pairwise_embeddings.size(0)
        max_num_trigs = pairwise_embeddings.size(1)
        max_num_args = pairwise_embeddings.size(2)
        argument_feedforward = self._argument_feedforwards[self._active_namespaces["argument"]]
        feature_dim = argument_feedforward.input_dim
        embeddings_flat = pairwise_embeddings.view(-1, feature_dim)

        arguments_projected_flat = argument_feedforward(embeddings_flat)
        argument_scorer = self._argument_scorers[self._active_namespaces["argument"]]
        argument_scores_flat = argument_scorer(arguments_projected_flat)

        argument_scores = argument_scores_flat.view(batch_size, max_num_trigs, max_num_args, -1)

        # Add the mention scores for each of the candidates.
        argument_scores += (top_trig_scores.unsqueeze(-1) +
                            top_arg_scores.transpose(1, 2).unsqueeze(-1))

        shape = [argument_scores.size(0), argument_scores.size(1), argument_scores.size(2), 1]
        dummy_scores = argument_scores.new_zeros(*shape)

        if prepend_zeros:
            argument_scores = torch.cat([dummy_scores, argument_scores], -1)
        return argument_scores

    ####################

    # Predictions / decoding.

    def predict(self, output_dict, document):
        """
        Take the output and convert it into a list of dicts. Each entry is a sentence. Each key
        is a pair of span indices for that sentence, and each value is the relation label on that
        span pair.
        """
        outputs = fields_to_batches({k: v.detach().cpu() for k, v in output_dict.items()})

        prediction_dicts = []
        predictions = []

        # Collect predictions for each sentence in minibatch.
        for output, sentence in zip(outputs, document):
            decoded_trig = self._decode_trigger(output)
            decoded_args = self._decode_arguments(output, decoded_trig)
            predicted_events = self._assemble_predictions(decoded_trig, decoded_args, sentence)
            prediction_dicts.append({"trigger_dict": decoded_trig,
                                     "argument_dict": decoded_args})
            predictions.append(predicted_events)

        return prediction_dicts, predictions

    def _decode_trigger(self, output):
        trigger_scores = output["trigger_scores"]
        predicted_scores_raw, predicted_triggers = trigger_scores.max(dim=1)
        softmax_scores = F.softmax(trigger_scores, dim=1)
        predicted_scores_softmax, _ = softmax_scores.max(dim=1)
        trigger_dict = {}
        # TODO(dwadden) Can speed this up with array ops.
        for i in range(output["sentence_lengths"]):
            trig_label = predicted_triggers[i].item()
            if trig_label > 0:
                predicted_label = self.vocab.get_token_from_index(
                    trig_label, namespace=self._active_namespaces["trigger"])
                trigger_dict[i] = (predicted_label,
                                   predicted_scores_raw[i].item(),
                                   predicted_scores_softmax[i].item())
        return trigger_dict

    def _decode_arguments(self, output, decoded_trig):
        # TODO(dwadden) Vectorize.
        argument_dict = {}
        argument_scores = output["argument_scores"]
        predicted_scores_raw, predicted_arguments = argument_scores.max(dim=-1)
        # The null argument has label -1.
        predicted_arguments -= 1
        softmax_scores = F.softmax(argument_scores, dim=-1)
        predicted_scores_softmax, _ = softmax_scores.max(dim=-1)
        for i, j in itertools.product(range(output["num_triggers_kept"]),
                                      range(output["num_argument_spans_kept"])):
            trig_ix = output["top_trigger_indices"][i].item()
            arg_span = tuple(output["top_argument_spans"][j].tolist())
            arg_label = predicted_arguments[i, j].item()
            # Only include the argument if its putative trigger is predicted as a real trigger.
            if arg_label >= 0 and trig_ix in decoded_trig:
                arg_score_raw = predicted_scores_raw[i, j].item()
                arg_score_softmax = predicted_scores_softmax[i, j].item()
                label_name = self.vocab.get_token_from_index(
                    arg_label, namespace=self._active_namespaces["argument"])
                argument_dict[(trig_ix, arg_span)] = (label_name,
                                                      arg_score_raw, arg_score_softmax)

        return argument_dict

    def _assemble_predictions(self, trigger_dict, argument_dict, sentence):
        events_json = []
        for trigger_ix, trigger_label in trigger_dict.items():
            this_event = []
            this_event.append([trigger_ix] + list(trigger_label))
            event_arguments = {k: v for k, v in argument_dict.items() if k[0] == trigger_ix}
            this_event_args = []
            for k, v in event_arguments.items():
                entry = list(k[1]) + list(v)
                this_event_args.append(entry)
            this_event_args = sorted(this_event_args, key=lambda entry: entry[0])
            this_event.extend(this_event_args)
            events_json.append(this_event)

        events = document.PredictedEvents(events_json, sentence, sentence_offsets=True)

        return events

    ####################

    # Loss function and evaluation metrics.

    @staticmethod
    def _get_pruned_gold_arguments(argument_labels, top_trig_indices, top_arg_indices,
                                   top_trig_masks, top_arg_masks):
        """
        Loop over each slice and get the labels for the spans from that slice.
        All labels are offset by 1 so that the "null" label gets class zero. This is the desired
        behavior for the softmax. Labels corresponding to masked relations keep the label -1,
        which the softmax loss ignores.
        """
        arguments = []

        zipped = zip(argument_labels, top_trig_indices, top_arg_indices,
                     top_trig_masks.bool(), top_arg_masks.bool())

        for sliced, trig_ixs, arg_ixs, trig_mask, arg_mask in zipped:
            entry = sliced[trig_ixs][:, arg_ixs].unsqueeze(0)
            mask_entry = trig_mask & arg_mask.transpose(0, 1).unsqueeze(0)
            entry[mask_entry] += 1
            entry[~mask_entry] = -1
            arguments.append(entry)

        return torch.cat(arguments, dim=0)

    def _get_trigger_loss(self, trigger_scores, trigger_labels, trigger_mask):
        n_trigger_labels = self._n_trigger_labels[self._active_namespaces["trigger"]]
        trigger_scores_flat = trigger_scores.view(-1, n_trigger_labels)
        trigger_labels_flat = trigger_labels.view(-1)
        mask_flat = trigger_mask.view(-1).bool()
        loss = self._trigger_loss(trigger_scores_flat[mask_flat], trigger_labels_flat[mask_flat])
        return loss

    def _get_argument_loss(self, argument_scores, argument_labels):
        """
        Compute cross-entropy loss on argument labels.
        """
        n_argument_labels = self._n_argument_labels[self._active_namespaces["argument"]]
        # Need to add one for the null class.
        scores_flat = argument_scores.view(-1, n_argument_labels + 1)
        # Need to add 1 so that the null label is 0, to line up with indices into the prediction
        # matrix.
        labels_flat = argument_labels.view(-1)
        # Compute cross-entropy loss.
        loss = self._argument_loss(scores_flat, labels_flat)
        return loss

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        f1_metrics = self._metrics.get_metric(reset)
        res = {}
        res.update(f1_metrics)
        return res
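# Usage sketch for the per-dataset namespaces (hedged; the dataset name is hypothetical): with
# a vocabulary holding "ace__trigger_labels" and "ace__argument_labels", a forward pass on a
# batch whose `metadata.dataset == "ace"` activates
#     {"trigger": "ace__trigger_labels", "argument": "ace__argument_labels"}
# and routes scoring through the matching entries of the ModuleDicts built in `__init__`.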