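# Two AllenNLP models for joint event trigger classification and argument role
# labeling over the SD4M data:
#   * SnorkelEventxModel -- trained on probabilistic (Snorkel-generated) labels;
#     triggers are classified per candidate span and the loss is a masked
#     cross-entropy over label distributions (abstained instances are masked out).
#   * EventxModel -- trained on categorical BIO trigger tags from the vocabulary
#     namespaces; triggers are classified per token and decoded with
#     bio_tags_to_spans.
# Both models score every (trigger, entity) pair by projecting the two
# representations into a shared hidden space and classifying their sum.
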
class SnorkelEventxModel(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 span_extractor: SpanExtractor,
                 entity_embedder: TokenEmbedder,
                 hidden_dim: int,
                 loss_weight: float = 1.0,
                 trigger_gamma: float = None,
                 role_gamma: float = None,
                 positive_class_weight: float = 1.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None) -> None:
        super().__init__(vocab=vocab, regularizer=regularizer)
        self.num_trigger_classes = len(SD4M_RELATION_TYPES)
        self.num_role_classes = len(ROLE_LABELS)
        self.hidden_dim = hidden_dim
        self.loss_weight = loss_weight
        self.trigger_gamma = trigger_gamma
        self.role_gamma = role_gamma
        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.entity_embedder = entity_embedder
        self.span_extractor = span_extractor
        self.trigger_projection = Linear(self.encoder.get_output_dim(),
                                         self.num_trigger_classes)
        self.trigger_to_hidden = Linear(self.encoder.get_output_dim(),
                                        self.hidden_dim)
        self.entities_to_hidden = Linear(self.encoder.get_output_dim(),
                                         self.hidden_dim)
        self.hidden_bias = Parameter(torch.Tensor(self.hidden_dim))
        torch.nn.init.normal_(self.hidden_bias)
        self.hidden_to_roles = Linear(self.hidden_dim, self.num_role_classes)
        self.trigger_accuracy = CategoricalAccuracy()

        trigger_labels_to_idx = {
            label: idx for idx, label in enumerate(SD4M_RELATION_TYPES)
        }
        evaluated_trigger_idxs = list(trigger_labels_to_idx.values())
        evaluated_trigger_idxs.remove(
            trigger_labels_to_idx[NEGATIVE_TRIGGER_LABEL])
        self.trigger_f1 = MicroFBetaMeasure(
            average='micro',  # Macro averaging in get_metrics
            labels=evaluated_trigger_idxs)
        self.trigger_classes_f1 = MicroFBetaMeasure(
            average=None, labels=evaluated_trigger_idxs)

        role_labels_to_idx = {
            label: idx for idx, label in enumerate(ROLE_LABELS)
        }
        evaluated_role_idxs = list(role_labels_to_idx.values())
        evaluated_role_idxs.remove(role_labels_to_idx[NEGATIVE_ARGUMENT_LABEL])
        self.role_accuracy = CategoricalAccuracy()
        self.role_f1 = MicroFBetaMeasure(
            average='micro',  # Macro averaging in get_metrics
            labels=evaluated_role_idxs)
        self.role_classes_f1 = MicroFBetaMeasure(average=None,
                                                 labels=evaluated_role_idxs)

        # Trigger class weighting as done in the JMEE repo
        self.trigger_class_weights = torch.ones(
            len(SD4M_RELATION_TYPES)) * positive_class_weight
        self.trigger_class_weights[
            trigger_labels_to_idx[NEGATIVE_TRIGGER_LABEL]] = 1.0

        initializer(self)

    @overrides
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                entity_tags: torch.LongTensor,
                entity_spans: torch.LongTensor,
                trigger_spans: torch.LongTensor,
                trigger_labels: torch.LongTensor = None,
                arg_roles: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None
                ) -> Dict[str, torch.Tensor]:
        embedded_tokens = self.text_field_embedder(tokens)
        text_mask = get_text_field_mask(tokens)
        embedded_entity_tags = self.entity_embedder(entity_tags)
        embedded_input = torch.cat([embedded_tokens, embedded_entity_tags],
                                   dim=-1)

        encoded_input = self.encoder(embedded_input, text_mask)

        ###########################
        # Trigger type prediction #
        ###########################

        # Extract the spans of the triggers
        trigger_spans_mask = (trigger_spans[:, :, 0] >= 0).long()
        encoded_triggers = self.span_extractor(
            sequence_tensor=encoded_input,
            span_indices=trigger_spans,
            sequence_mask=text_mask,
            span_indices_mask=trigger_spans_mask)

        # Pass the extracted triggers through a projection for classification
        trigger_logits = self.trigger_projection(encoded_triggers)

        # Add the trigger predictions to the output
        trigger_probabilities = F.softmax(trigger_logits, dim=-1)
        output_dict = {
            "trigger_logits": trigger_logits,
            "trigger_probabilities": trigger_probabilities
        }

        if trigger_labels is not None:
            # Compute loss and metrics using the given trigger labels.
            # The trigger mask filters out padding (and instances on which the
            # Snorkel labeling abstained).
            trigger_mask = trigger_labels.sum(dim=2) > 0  # B x T

            # Trigger class probabilities to label
            decoded_trigger_labels = trigger_labels.argmax(dim=2)

            self.trigger_accuracy(trigger_logits, decoded_trigger_labels,
                                  trigger_mask.float())
            self.trigger_f1(trigger_logits, decoded_trigger_labels,
                            trigger_mask.float())
            self.trigger_classes_f1(trigger_logits, decoded_trigger_labels,
                                    trigger_mask.float())

            trigger_logits_t = trigger_logits.permute(0, 2, 1)
            trigger_loss = self._cross_entropy_loss(
                logits=trigger_logits_t,
                target=trigger_labels,
                target_mask=trigger_mask,
                weight=self.trigger_class_weights)

            output_dict["triggers_loss"] = trigger_loss
            output_dict["loss"] = trigger_loss

        ########################################
        # Argument detection and role labeling #
        ########################################

        # Extract the spans of the encoded entities
        entity_spans_mask = (entity_spans[:, :, 0] >= 0).long()
        encoded_entities = self.span_extractor(
            sequence_tensor=encoded_input,
            span_indices=entity_spans,
            sequence_mask=text_mask,
            span_indices_mask=entity_spans_mask)

        # Project both triggers and entities/args into a 'hidden' comparison space
        triggers_hidden = self.trigger_to_hidden(encoded_triggers)
        args_hidden = self.entities_to_hidden(encoded_entities)

        # Create the cross-product of triggers and args via broadcasting
        trigger = triggers_hidden.unsqueeze(2)  # B x T x 1 x H
        args = args_hidden.unsqueeze(1)  # B x 1 x E x H
        trigger_arg = trigger + args + self.hidden_bias  # B x T x E x H

        # Pass through activation and projection for classification
        role_activations = F.relu(trigger_arg)
        role_logits = self.hidden_to_roles(role_activations)  # B x T x E x R

        # Add the role predictions to the output
        role_probabilities = torch.softmax(role_logits, dim=-1)
        output_dict['role_logits'] = role_logits
        output_dict['role_probabilities'] = role_probabilities

        # Compute loss and metrics using the given role labels
        if arg_roles is not None:
            arg_roles = self._assert_target_shape(logits=role_logits,
                                                  target=arg_roles,
                                                  fill_value=0)

            target_mask = arg_roles.sum(dim=3) > 0  # B x T x E

            # Role class probabilities to label
            decoded_target = arg_roles.argmax(dim=3)

            self.role_accuracy(role_logits, decoded_target,
                               target_mask.float())
            self.role_f1(role_logits, decoded_target, target_mask.float())
            self.role_classes_f1(role_logits, decoded_target,
                                 target_mask.float())

            # Masked batch-wise cross entropy loss
            role_logits_t = role_logits.permute(0, 3, 1, 2)
            role_loss = self._cross_entropy_loss(logits=role_logits_t,
                                                 target=arg_roles,
                                                 target_mask=target_mask)

            output_dict['role_loss'] = role_loss
            output_dict['loss'] += self.loss_weight * role_loss

        # Append the original tokens for visualization
        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]

        # Append the trigger and entity spans to reconstruct the event after prediction
        output_dict['entity_spans'] = entity_spans
        output_dict['trigger_spans'] = trigger_spans

        return output_dict

    @overrides
    def decode(
        self, output_dict: Dict[str, Union[torch.Tensor, List]]
    ) -> Dict[str, torch.Tensor]:
        trigger_predictions = output_dict['trigger_probabilities'] \
            .cpu().data.numpy()
        trigger_labels = [[
            SD4M_RELATION_TYPES[trigger_idx] for trigger_idx in example
        ] for example in np.argmax(trigger_predictions, axis=-1)]
        output_dict['trigger_labels'] = trigger_labels

        arg_role_predictions = output_dict['role_logits'].cpu().data.numpy()
        arg_role_labels = [[[ROLE_LABELS[role_idx] for role_idx in event]
                            for event in example]
                           for example in np.argmax(arg_role_predictions,
                                                    axis=-1)]
        output_dict['role_labels'] = arg_role_labels

        events = []
        for batch_idx in range(len(trigger_labels)):
            words = output_dict['words'][batch_idx]
            batch_events = []
            for trigger_idx, trigger_label in enumerate(
                    trigger_labels[batch_idx]):
                if trigger_label == NEGATIVE_TRIGGER_LABEL:
                    continue
                trigger_span = output_dict['trigger_spans'][batch_idx][
                    trigger_idx]
                trigger_start = trigger_span[0].item()
                trigger_end = trigger_span[1].item() + 1
                if trigger_start < 0:
                    continue
                event = {
                    'event_type': trigger_label,
                    'trigger': {
                        'text': " ".join(words[trigger_start:trigger_end]),
                        'start': trigger_start,
                        'end': trigger_end
                    },
                    'arguments': []
                }
                for entity_idx, role_label in enumerate(
                        arg_role_labels[batch_idx][trigger_idx]):
                    if role_label == NEGATIVE_ARGUMENT_LABEL:
                        continue
                    arg_span = output_dict['entity_spans'][batch_idx][
                        entity_idx]
                    arg_start = arg_span[0].item()
                    arg_end = arg_span[1].item() + 1
                    if arg_start < 0:
                        continue
                    argument = {
                        'text': " ".join(words[arg_start:arg_end]),
                        'start': arg_start,
                        'end': arg_end,
                        'role': role_label
                    }
                    event['arguments'].append(argument)
                if len(event['arguments']) > 0:
                    batch_events.append(event)
            events.append(batch_events)
        output_dict['events'] = events
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        metrics_to_return = {
            'trigger_acc': self.trigger_accuracy.get_metric(reset=reset),
            'trigger_f1': self.trigger_f1.get_metric(reset=reset)['fscore'],
            'role_acc': self.role_accuracy.get_metric(reset=reset),
            'role_f1': self.role_f1.get_metric(reset=reset)['fscore']
        }
        trigger_classes_f1 = self.trigger_classes_f1.get_metric(
            reset=reset)['fscore']
        role_classes_f1 = self.role_classes_f1.get_metric(
            reset=reset)['fscore']
        for trigger_class, class_f1 in zip(SD4M_RELATION_TYPES[:-1],
                                           trigger_classes_f1):
            metrics_to_return['_' + trigger_class + '_f1'] = class_f1
        for role_class, class_f1 in zip(ROLE_LABELS[:-1], role_classes_f1):
            metrics_to_return['_' + role_class + '_f1'] = class_f1
        return metrics_to_return

    @staticmethod
    def _assert_target_shape(logits, target, fill_value=0):
        """
        Ensures that the target tensor has the same shape as the logits.
        This is not always the case, since some batches are not completely
        filled.
        """
        expected_shape = logits.shape
        if target.shape == expected_shape:
            return target
        new_target = torch.full(size=expected_shape,
                                fill_value=fill_value,
                                dtype=target.dtype,
                                device=target.device)
        batch_size, triggers_len, arguments_len, _ = target.shape
        new_target[:, :triggers_len, :arguments_len] = target
        return new_target

    @staticmethod
    def _cross_entropy_loss(logits, target, target_mask,
                            weight=None) -> torch.Tensor:
        loss_unreduced = cross_entropy_with_probs(logits,
                                                  target,
                                                  reduction="none",
                                                  weight=weight)
        masked_loss = loss_unreduced * target_mask
        batch_size = target.size(0)
        loss_per_batch = masked_loss.view(batch_size, -1).sum(dim=1)
        # Normalize each example's loss by its number of unmasked targets
        mask_per_batch = target_mask.view(batch_size, -1).sum(dim=1)
        return (loss_per_batch / mask_per_batch).sum() / batch_size
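
# Shape sketch for the trigger/argument pairing used by both models
# (illustrative only): with batch size B, T trigger spans (or L tokens in
# EventxModel), E entity spans and hidden size H,
#   triggers_hidden.unsqueeze(2)  ->  B x T x 1 x H
#   args_hidden.unsqueeze(1)      ->  B x 1 x E x H
# and broadcasting their sum (plus the learned bias) yields a single
# B x T x E x H tensor, i.e. one hidden vector per (trigger, entity) pair.
# For example, torch.randn(2, 3, 1, 8) + torch.randn(2, 1, 5, 8) broadcasts
# to shape (2, 3, 5, 8).
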
class EventxModel(Model):
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 span_extractor: SpanExtractor,
                 entity_embedder: TokenEmbedder,
                 trigger_embedder: TokenEmbedder,
                 hidden_dim: int,
                 loss_weight: float = 1.0,
                 trigger_gamma: float = None,
                 role_gamma: float = None,
                 triggers_namespace: str = 'trigger_labels',
                 roles_namespace: str = 'arg_role_labels',
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: RegularizerApplicator = None) -> None:
        super().__init__(vocab=vocab, regularizer=regularizer)
        self._triggers_namespace = triggers_namespace
        self._roles_namespace = roles_namespace
        self.num_trigger_classes = self.vocab.get_vocab_size(
            triggers_namespace)
        self.num_role_classes = self.vocab.get_vocab_size(roles_namespace)
        self.hidden_dim = hidden_dim
        self.loss_weight = loss_weight
        self.trigger_gamma = trigger_gamma
        self.role_gamma = role_gamma
        self.text_field_embedder = text_field_embedder
        self.encoder = encoder
        self.entity_embedder = entity_embedder
        self.trigger_embedder = trigger_embedder
        self.span_extractor = span_extractor
        self.trigger_projection = TimeDistributed(
            Linear(self.encoder.get_output_dim(), self.num_trigger_classes))
        self.trigger_to_hidden = Linear(
            self.encoder.get_output_dim() +
            self.trigger_embedder.get_output_dim(), self.hidden_dim)
        self.entities_to_hidden = Linear(
            self.encoder.get_output_dim() +
            self.entity_embedder.get_output_dim(), self.hidden_dim)
        self.hidden_bias = Parameter(torch.Tensor(self.hidden_dim))
        torch.nn.init.normal_(self.hidden_bias)
        self.hidden_to_roles = Linear(self.hidden_dim, self.num_role_classes)
        self.trigger_accuracy = CategoricalAccuracy()
        self.trigger_f1 = SpanBasedF1Measure(
            vocab,
            tag_namespace=triggers_namespace,
            label_encoding="BIO",
            ignore_classes=[NEGATIVE_TRIGGER_LABEL])
        role_labels_to_idx = self.vocab.get_token_to_index_vocabulary(
            namespace=roles_namespace)
        evaluated_role_idxs = list(role_labels_to_idx.values())
        evaluated_role_idxs.remove(role_labels_to_idx[NEGATIVE_ARGUMENT_LABEL])
        self.role_accuracy = CategoricalAccuracy()
        self.role_f1 = MicroFBetaMeasure(
            average='micro',  # Macro averaging in get_metrics
            labels=evaluated_role_idxs)
        initializer(self)

    @overrides
    def forward(self,
                tokens: Dict[str, torch.LongTensor],
                entity_labels: torch.LongTensor,
                entity_spans: torch.LongTensor,
                triggers: torch.LongTensor = None,
                arg_roles: torch.LongTensor = None,
                metadata: List[Dict[str, Any]] = None
                ) -> Dict[str, torch.Tensor]:
        embedded_tokens = self.text_field_embedder(tokens)
        text_mask = get_text_field_mask(tokens)
        encoded_tokens = self.encoder(embedded_tokens, text_mask)

        ###########################
        # Trigger type prediction #
        ###########################

        # Pass the encoded tokens through a projection for token-level
        # trigger classification
        trigger_logits = self.trigger_projection(encoded_tokens)
        trigger_probabilities = F.softmax(trigger_logits, dim=-1)
        trigger_predictions = trigger_logits.argmax(dim=-1)
        output_dict = {
            "trigger_logits": trigger_logits,
            "trigger_probabilities": trigger_probabilities
        }

        if triggers is not None:
            self.trigger_accuracy(trigger_logits, triggers, text_mask.float())
            self.trigger_f1(trigger_logits, triggers, text_mask.float())
            loss = sequence_cross_entropy_with_logits(logits=trigger_logits,
                                                      targets=triggers,
                                                      weights=text_mask,
                                                      gamma=self.trigger_gamma)
            output_dict["triggers_loss"] = loss
            output_dict["loss"] = loss

        ########################################
        # Argument detection and role labeling #
        ########################################

        # Extract the spans of the encoded entities
        entity_spans_mask = (entity_spans[:, :, 0] >= 0).long()
        encoded_entities = self.span_extractor(
            sequence_tensor=encoded_tokens,
            span_indices=entity_spans,
            sequence_mask=text_mask,
            span_indices_mask=entity_spans_mask)

        # Project both triggers and entities/args into a 'hidden' comparison space
        entity_label_mask = (entity_labels != -1)
        entity_labels = entity_labels * entity_label_mask
        embedded_entity_labels = self.entity_embedder(entity_labels)
        embedded_trigger_labels = self.trigger_embedder(trigger_predictions)
        triggers_hidden = self.trigger_to_hidden(
            torch.cat([encoded_tokens, embedded_trigger_labels],
                      dim=-1))  # B x L x H
        entities_hidden = self.entities_to_hidden(
            torch.cat([encoded_entities, embedded_entity_labels],
                      dim=-1))  # B x E x H

        # Create the cross-product of triggers and args via broadcasting
        trigger = triggers_hidden.unsqueeze(2)  # Shape: B x L x 1 x H
        args = entities_hidden.unsqueeze(1)  # Shape: B x 1 x E x H
        trigger_arg = trigger + args + self.hidden_bias  # B x L x E x H

        # Pass through activation and projection for classification
        role_activations = F.relu(trigger_arg)
        role_logits = self.hidden_to_roles(role_activations)  # B x L x E x R

        # Add the role predictions to the output
        role_probabilities = torch.softmax(role_logits, dim=-1)
        output_dict['role_logits'] = role_logits
        output_dict['role_probabilities'] = role_probabilities

        # Compute loss and metrics using the given role labels
        if arg_roles is not None:
            arg_roles = self._assert_target_seq_len(
                seq_len=embedded_tokens.shape[1], target=arg_roles)
            target_mask = (arg_roles != -1)
            target = arg_roles * target_mask
            self.role_accuracy(role_logits, target, target_mask.float())
            self.role_f1(role_logits, target, target_mask.float())
            # Masked batch-wise cross entropy loss, optionally with focal loss
            role_logits_t = role_logits.permute(0, 3, 1, 2)
            role_loss = cross_entropy_focal_loss(logits=role_logits_t,
                                                 target=target,
                                                 target_mask=target_mask,
                                                 gamma=self.role_gamma)
            output_dict['role_loss'] = role_loss
            output_dict['loss'] += self.loss_weight * role_loss

        # Append the original tokens for visualization
        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]

        # Append the entity spans to reconstruct the event after prediction
        output_dict['entity_spans'] = entity_spans

        return output_dict

    @overrides
    def decode(
            self,
            output_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        trigger_probabilities = output_dict['trigger_probabilities'] \
            .cpu().data.numpy()
        trigger_predictions = np.argmax(trigger_probabilities, axis=-1)
        trigger_tags = []
        for batch_idx in range(len(trigger_predictions)):
            # Based on the number of words, get rid of trigger padding in batches
            words = output_dict['words'][batch_idx]
            trigger_tags.append([
                self.vocab.get_token_from_index(
                    trigger_idx, namespace=self._triggers_namespace)
                for trigger_idx in trigger_predictions[batch_idx][:len(words)]
            ])
        output_dict['trigger_tags'] = trigger_tags

        # Convert to trigger labels with inclusive spans: Tuple[str, Tuple[int, int]]
        trigger_labels = [
            bio_tags_to_spans(example) for example in trigger_tags
        ]

        arg_role_probabilities = output_dict['role_logits'].cpu().data.numpy()
        arg_role_predictions = np.argmax(arg_role_probabilities, axis=-1)
        arg_role_labels = []
        for batch_idx in range(len(arg_role_predictions)):
            # Based on the number of words and entities, get rid of arg role
            # padding in batches
            words = output_dict['words'][batch_idx]
            entity_spans = [
                entity_span
                for entity_span in output_dict['entity_spans'][batch_idx]
                if entity_span[0] > -1
            ]
            arg_role_labels.append([[
                self.vocab.get_token_from_index(
                    role_idx, namespace=self._roles_namespace)
                for role_idx in event[:len(entity_spans)]
            ] for event in arg_role_predictions[batch_idx][:len(words)]])
        output_dict['role_labels'] = arg_role_labels

        events = []
        for batch_idx in range(len(trigger_labels)):
            words = output_dict['words'][batch_idx]
            batch_events = []
            for trigger_idx, trigger_label_with_span in enumerate(
                    trigger_labels[batch_idx]):
                trigger_label, trigger_span = trigger_label_with_span
                if trigger_label == NEGATIVE_TRIGGER_LABEL:
                    continue
                trigger_start = trigger_span[0]
                trigger_end = trigger_span[1] + 1
                event = {
                    'event_type': trigger_label,
                    'trigger': {
                        'text': " ".join(words[trigger_start:trigger_end]),
                        'start': trigger_start,
                        'end': trigger_end
                    },
                    'arguments': []
                }
                # Group role probabilities by predicted trigger, then sum and
                # argmax to extract the role label in case of a multi-token trigger
                rel_arg_role_probs = arg_role_probabilities[batch_idx][
                    trigger_start:trigger_end].sum(axis=0)
                for entity_idx, role_probs in enumerate(rel_arg_role_probs):
                    role_idx = role_probs.argmax()
                    role_label = self.vocab.get_token_from_index(
                        role_idx, namespace=self._roles_namespace)
                    if role_label == NEGATIVE_ARGUMENT_LABEL:
                        continue
                    arg_span = output_dict['entity_spans'][batch_idx][
                        entity_idx]
                    arg_start = arg_span[0].item()
                    arg_end = arg_span[1].item() + 1
                    argument = {
                        'text': " ".join(words[arg_start:arg_end]),
                        'start': arg_start,
                        'end': arg_end,
                        'role': role_label
                    }
                    event['arguments'].append(argument)
                # if len(event['arguments']) > 0:
                #     batch_events.append(event)
                batch_events.append(event)
            events.append(batch_events)
        output_dict['events'] = events
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        return {
            'trigger_acc': self.trigger_accuracy.get_metric(reset=reset),
            'trigger_f1':
                self.trigger_f1.get_metric(reset=reset)['f1-measure-overall'],
            'role_acc': self.role_accuracy.get_metric(reset=reset),
            'role_f1': self.role_f1.get_metric(reset=reset)['fscore']
        }

    @staticmethod
    def _assert_target_seq_len(seq_len, target):
        """
        In some batches the longest sentence does not contain any entities.
        This results in a target tensor that is not padded to the full
        sequence length, so pad it here.
        """
        batch_size, target_seq_len, num_spans = target.size()
        if seq_len == target_seq_len:
            return target
        missing_padding = seq_len - target_seq_len
        padding_size = (batch_size, missing_padding, num_spans)
        padding_tensor = torch.full(size=padding_size,
                                    fill_value=-1,
                                    dtype=target.dtype,
                                    device=target.device)
        return torch.cat([target, padding_tensor], dim=1)
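

if __name__ == "__main__":
    # Minimal sanity-check sketch, not used by either model; the label string
    # below is a placeholder. AllenNLP's bio_tags_to_spans (already imported at
    # module level for EventxModel.decode) returns (label, (start, end)) tuples
    # with an *inclusive* end index, which is why decode() adds +1 before
    # slicing the token list.
    example_tags = ["O", "B-EventA", "I-EventA", "O"]
    spans = bio_tags_to_spans(example_tags)
    assert ("EventA", (1, 2)) in spans  # tokens 1..2 inclusive -> words[1:3]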