def forward(
        self,
        tokens: Dict[str, torch.LongTensor],
        entity_tags: torch.LongTensor,
        entity_spans: torch.LongTensor,
        trigger_spans: torch.LongTensor,
        trigger_labels: torch.LongTensor = None,
        arg_roles: torch.LongTensor = None,
        metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
    """Predict trigger types and argument roles for a batch.

    Tokens are embedded, concatenated with entity-tag embeddings and encoded.
    Each trigger span is classified into a trigger type. Every
    (trigger, entity) pair is scored into an argument role by projecting both
    spans into a shared hidden space, adding them with a bias and applying ReLU.

    Args:
        tokens: Text-field tensors consumed by ``self.text_field_embedder`` and
            ``get_text_field_mask``.
        entity_tags: Per-token entity tags embedded by ``self.entity_embedder``.
        entity_spans: Span indices of candidate arguments, indexed as
            ``[:, :, 0]`` for the start. A negative start marks a padding span.
        trigger_spans: Span indices of triggers. A negative start marks padding.
        trigger_labels: Optional per-trigger label distribution, reduced over
            ``dim=2`` (presumably B x T x C; possibly soft Snorkel labels --
            confirm). Triggers whose row sums to 0 are excluded from loss
            and metrics.
        arg_roles: Optional per-(trigger, entity) role distribution, passed
            through ``self._assert_target_shape`` with ``fill_value=0`` and
            then reduced over ``dim=3``. All-zero rows are excluded from loss
            and metrics.
        metadata: Optional per-instance dicts. Their ``"words"`` entries are
            copied to the output.

    Returns:
        A dict with ``trigger_logits``, ``trigger_probabilities``,
        ``role_logits``, ``role_probabilities``, ``entity_spans`` and
        ``trigger_spans``. It also contains ``triggers_loss`` when
        ``trigger_labels`` is given, and ``role_loss`` when ``arg_roles`` is
        given. ``loss`` is present when either is given: it is the trigger
        loss plus ``self.loss_weight * role_loss`` when both are available.
        ``words`` is included when ``metadata`` is given.
    """
    embedded_tokens = self.text_field_embedder(tokens)
    text_mask = get_text_field_mask(tokens)
    embedded_entity_tags = self.entity_embedder(entity_tags)
    embedded_input = torch.cat([embedded_tokens, embedded_entity_tags], dim=-1)
    encoded_input = self.encoder(embedded_input, text_mask)

    ###########################
    # Trigger type prediction #
    ###########################

    # Extract the spans of the triggers; a negative start index marks padding.
    trigger_spans_mask = (trigger_spans[:, :, 0] >= 0).long()
    encoded_triggers = self.span_extractor(
        sequence_tensor=encoded_input,
        span_indices=trigger_spans,
        sequence_mask=text_mask,
        span_indices_mask=trigger_spans_mask)

    # Pass the extracted triggers through a projection for classification
    trigger_logits = self.trigger_projection(encoded_triggers)

    # Add the trigger predictions to the output
    trigger_probabilities = F.softmax(trigger_logits, dim=-1)
    output_dict = {
        "trigger_logits": trigger_logits,
        "trigger_probabilities": trigger_probabilities,
    }

    if trigger_labels is not None:
        # Compute loss and metrics using the given trigger labels.
        # Trigger mask filters out padding (and abstained instances from snorkel labeling)
        trigger_mask = trigger_labels.sum(dim=2) > 0  # B x T
        trigger_mask_float = trigger_mask.float()
        # Trigger class probabilities to label
        decoded_trigger_labels = trigger_labels.argmax(dim=2)
        self.trigger_accuracy(trigger_logits, decoded_trigger_labels, trigger_mask_float)
        self.trigger_f1(trigger_logits, decoded_trigger_labels, trigger_mask_float)
        self.trigger_classes_f1(trigger_logits, decoded_trigger_labels, trigger_mask_float)
        # Move the class axis to dim 1; presumably _cross_entropy_loss expects
        # classes second, as F.cross_entropy does -- confirm.
        trigger_logits_t = trigger_logits.permute(0, 2, 1)
        trigger_loss = self._cross_entropy_loss(
            logits=trigger_logits_t,
            target=trigger_labels,
            target_mask=trigger_mask,
            weight=self.trigger_class_weights)
        output_dict["triggers_loss"] = trigger_loss
        output_dict["loss"] = trigger_loss

    ########################################
    # Argument detection and role labeling #
    ########################################

    # Extract the spans of the encoded entities; negative start marks padding.
    entity_spans_mask = (entity_spans[:, :, 0] >= 0).long()
    encoded_entities = self.span_extractor(
        sequence_tensor=encoded_input,
        span_indices=entity_spans,
        sequence_mask=text_mask,
        span_indices_mask=entity_spans_mask)

    # Project both triggers and entities/args into a 'hidden' comparison space
    triggers_hidden = self.trigger_to_hidden(encoded_triggers)
    args_hidden = self.entities_to_hidden(encoded_entities)

    # Create the cross-product of triggers and args via broadcasting
    trigger = triggers_hidden.unsqueeze(2)  # B x T x 1 x H
    args = args_hidden.unsqueeze(1)  # B x 1 x E x H
    trigger_arg = trigger + args + self.hidden_bias  # B x T x E x H

    # Pass through activation and projection for classification
    role_activations = F.relu(trigger_arg)
    role_logits = self.hidden_to_roles(role_activations)  # B x T x E x R

    # Add the role predictions to the output
    role_probabilities = F.softmax(role_logits, dim=-1)
    output_dict["role_logits"] = role_logits
    output_dict["role_probabilities"] = role_probabilities

    # Compute loss and metrics using the given role labels
    if arg_roles is not None:
        arg_roles = self._assert_target_shape(logits=role_logits, target=arg_roles, fill_value=0)

        target_mask = arg_roles.sum(dim=3) > 0  # B x T x E
        target_mask_float = target_mask.float()
        # Role class probabilities to label
        decoded_target = arg_roles.argmax(dim=3)

        self.role_accuracy(role_logits, decoded_target, target_mask_float)
        self.role_f1(role_logits, decoded_target, target_mask_float)
        self.role_classes_f1(role_logits, decoded_target, target_mask_float)

        # Masked batch-wise cross entropy loss (class axis moved to dim 1)
        role_logits_t = role_logits.permute(0, 3, 1, 2)
        role_loss = self._cross_entropy_loss(
            logits=role_logits_t,
            target=arg_roles,
            target_mask=target_mask)

        output_dict["role_loss"] = role_loss
        weighted_role_loss = self.loss_weight * role_loss
        if "loss" in output_dict:
            # Out-of-place add: "loss" is the very same tensor object as
            # "triggers_loss", so an in-place `+=` (Tensor.__iadd__) would
            # silently overwrite the reported trigger loss as well.
            output_dict["loss"] = output_dict["loss"] + weighted_role_loss
        else:
            # No trigger labels were given, so there is no trigger loss to add to.
            output_dict["loss"] = weighted_role_loss

    # Append the original tokens for visualization
    if metadata is not None:
        output_dict["words"] = [x["words"] for x in metadata]

    # Append the trigger and entity spans to reconstruct the event after prediction
    output_dict["entity_spans"] = entity_spans
    output_dict["trigger_spans"] = trigger_spans

    return output_dict