Example #1
# Imports this snippet relies on (not shown in the original excerpt)
from typing import Any, Dict, List

import torch
import torch.nn.functional as F
from allennlp.nn.util import get_text_field_mask
    def forward(
            self,
            tokens: Dict[str, torch.LongTensor],
            entity_tags: torch.LongTensor,
            entity_spans: torch.LongTensor,
            trigger_spans: torch.LongTensor,
            trigger_labels: torch.LongTensor = None,
            arg_roles: torch.LongTensor = None,
            metadata: List[Dict[str, Any]] = None) -> Dict[str, torch.Tensor]:
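        # Embed tokens and entity tags, then concatenate them as encoder input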
        embedded_tokens = self.text_field_embedder(tokens)
        text_mask = get_text_field_mask(tokens)
        embedded_entity_tags = self.entity_embedder(entity_tags)
        embedded_input = torch.cat([embedded_tokens, embedded_entity_tags],
                                   dim=-1)

        encoded_input = self.encoder(embedded_input, text_mask)

        ###########################
        # Trigger type prediction #
        ###########################

        # Extract the spans of the triggers
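        # Padded spans are encoded as (-1, -1), so a non-negative start index
        # marks a real span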
        trigger_spans_mask = (trigger_spans[:, :, 0] >= 0).long()
        encoded_triggers = self.span_extractor(
            sequence_tensor=encoded_input,
            span_indices=trigger_spans,
            sequence_mask=text_mask,
            span_indices_mask=trigger_spans_mask)

        # Pass the extracted triggers through a projection for classification
        trigger_logits = self.trigger_projection(encoded_triggers)

        # Add the trigger predictions to the output
        trigger_probabilities = F.softmax(trigger_logits, dim=-1)
        output_dict = {
            "trigger_logits": trigger_logits,
            "trigger_probabilities": trigger_probabilities
        }

        if trigger_labels is not None:
            # Compute loss and metrics using the given trigger labels
            # The trigger mask filters out padding (and instances on which
            # the Snorkel labeling abstained)
            trigger_mask = trigger_labels.sum(dim=2) > 0  # B x T
            # Collapse the soft trigger label distributions to hard labels
            decoded_trigger_labels = trigger_labels.argmax(dim=2)

            self.trigger_accuracy(trigger_logits, decoded_trigger_labels,
                                  trigger_mask.float())
            self.trigger_f1(trigger_logits, decoded_trigger_labels,
                            trigger_mask.float())
            self.trigger_classes_f1(trigger_logits, decoded_trigger_labels,
                                    trigger_mask.float())

            # Move the class dimension to position 1 (B x C x T), the layout
            # cross-entropy losses expect
            trigger_logits_t = trigger_logits.permute(0, 2, 1)
            trigger_loss = self._cross_entropy_loss(
                logits=trigger_logits_t,
                target=trigger_labels,
                target_mask=trigger_mask,
                weight=self.trigger_class_weights)

            output_dict["triggers_loss"] = trigger_loss
            output_dict["loss"] = trigger_loss

        ########################################
        # Argument detection and role labeling #
        ########################################

        # Extract the spans of the encoded entities
        entity_spans_mask = (entity_spans[:, :, 0] >= 0).long()
        encoded_entities = self.span_extractor(
            sequence_tensor=encoded_input,
            span_indices=entity_spans,
            sequence_mask=text_mask,
            span_indices_mask=entity_spans_mask)

        # Project both triggers and entities/args into a 'hidden' comparison space
        triggers_hidden = self.trigger_to_hidden(encoded_triggers)
        args_hidden = self.entities_to_hidden(encoded_entities)

        # Create the cross-product of triggers and args via broadcasting
        trigger = triggers_hidden.unsqueeze(2)  # B x T x 1 x H
        args = args_hidden.unsqueeze(1)  # B x 1 x E x H
        trigger_arg = trigger + args + self.hidden_bias  # B x T x E x H

        # Pass through activation and projection for classification
        role_activations = F.relu(trigger_arg)
        role_logits = self.hidden_to_roles(role_activations)  # B x T x E x R

        # Add the role predictions to the output
        role_probabilities = F.softmax(role_logits, dim=-1)
        output_dict['role_logits'] = role_logits
        output_dict['role_probabilities'] = role_probabilities

        # Compute loss and metrics using the given role labels
        if arg_roles is not None:
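            # Pad the gold role tensor to match the (padded) logits shape;
            # the helper is not shown in this excerpt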
            arg_roles = self._assert_target_shape(logits=role_logits,
                                                  target=arg_roles,
                                                  fill_value=0)

            target_mask = arg_roles.sum(dim=3) > 0  # B x T x E
            # Collapse the soft role label distributions to hard labels
            decoded_target = arg_roles.argmax(dim=3)

            self.role_accuracy(role_logits, decoded_target,
                               target_mask.float())
            self.role_f1(role_logits, decoded_target, target_mask.float())
            self.role_classes_f1(role_logits, decoded_target,
                                 target_mask.float())

            # Masked batch-wise cross-entropy loss; move the role dimension to
            # position 1 (B x R x T x E), the layout cross-entropy losses expect
            role_logits_t = role_logits.permute(0, 3, 1, 2)
            role_loss = self._cross_entropy_loss(logits=role_logits_t,
                                                 target=arg_roles,
                                                 target_mask=target_mask)

            output_dict['role_loss'] = role_loss
            # 'loss' is absent when trigger_labels was not given, hence the .get
            output_dict['loss'] = output_dict.get('loss', 0.0) \
                + self.loss_weight * role_loss

        # Append the original tokens for visualization
        if metadata is not None:
            output_dict["words"] = [x["words"] for x in metadata]

        # Append the trigger and entity spans to reconstruct the event after prediction
        output_dict['entity_spans'] = entity_spans
        output_dict['trigger_spans'] = trigger_spans

        return output_dict
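
The trigger-argument cross-product above relies only on tensor broadcasting, so it can be checked in isolation. A minimal, self-contained sketch with toy shapes and plain tensors standing in for the model's projection layers (all names below are illustrative, not from the model):

import torch
import torch.nn.functional as F

B, T, E, H, R = 2, 3, 4, 8, 5  # batch, triggers, entities, hidden size, roles

triggers_hidden = torch.randn(B, T, H)  # stand-in for trigger_to_hidden output
args_hidden = torch.randn(B, E, H)      # stand-in for entities_to_hidden output
hidden_bias = torch.randn(H)

# (B, T, 1, H) + (B, 1, E, H) broadcasts to (B, T, E, H):
# one combined vector for every trigger/argument pair
trigger_arg = triggers_hidden.unsqueeze(2) + args_hidden.unsqueeze(1) + hidden_bias
assert trigger_arg.shape == (B, T, E, H)

hidden_to_roles = torch.nn.Linear(H, R)  # applies to the last dimension
role_logits = hidden_to_roles(F.relu(trigger_arg))
assert role_logits.shape == (B, T, E, R)

The _cross_entropy_loss helper itself is not part of this excerpt. Given that the targets are soft label distributions (e.g. produced by Snorkel) and that the class dimension of the logits is moved to position 1 before the call, a plausible stand-in is the masked soft-label cross entropy sketched here; the model's actual helper may differ:

def soft_cross_entropy(logits, target, target_mask, weight=None):
    # logits:      (B, C, ...) with the class dimension second
    # target:      (B, ..., C) soft label distributions
    # target_mask: (B, ...) True where a position should contribute
    # weight:      optional (C,) per-class weights
    log_probs = F.log_softmax(logits, dim=1)
    target = target.movedim(-1, 1)  # align class dims: (B, C, ...)
    nll = -(target * log_probs)
    if weight is not None:
        nll = nll * weight.view(1, -1, *([1] * (nll.dim() - 2)))
    nll = nll.sum(dim=1)  # sum over classes -> (B, ...)
    mask = target_mask.float()
    return (nll * mask).sum() / mask.sum().clamp(min=1.0)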