Example #1
File: events.py  Project: MSLars/mare
    def __init__(
            self,
            vocab: Vocabulary,
            make_feedforward: Callable,
            text_emb_dim: int,
            trigger_emb_dim: int,  # Triggers are represented via span embeddings (but can have a different width than arg spans).
            span_emb_dim: int,  # Arguments are represented via span embeddings.
            feature_size: int,
            trigger_spans_per_word: float,
            argument_spans_per_word: float,
            loss_weights: Dict[str, float],
            context_window: int = 0,
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EventExtractor, self).__init__(vocab, regularizer)

        self._trigger_namespaces = [
            entry for entry in vocab.get_namespaces()
            if "trigger_labels" in entry
        ]
        self._argument_namespaces = [
            entry for entry in vocab.get_namespaces()
            if "argument_labels" in entry
        ]

        self._n_trigger_labels = {
            name: vocab.get_vocab_size(name)
            for name in self._trigger_namespaces
        }
        self._n_argument_labels = {
            name: vocab.get_vocab_size(name)
            for name in self._argument_namespaces
        }

        # Context window
        self._context_window = context_window  # If greater than 0, concatenate context as features.
        context_window_dim = 4 * self._context_window * text_emb_dim
        # 2 (arg context + trig context) * 2 (left context + right context) * context_window tokens, each of dimension text_emb_dim.
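        # Worked example (hypothetical sizes): with context_window = 2 and
        # text_emb_dim = 768, context_window_dim = 4 * 2 * 768 = 6144.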

        # Make sure the null trigger label is always 0.
        for namespace in self._trigger_namespaces:
            null_label = vocab.get_token_index("", namespace)
            assert null_label == 0  # If not, the dummy class won't correspond to the null label.

        # Create trigger scorers and pruners.
        self._trigger_scorers = torch.nn.ModuleDict()
        self._trigger_pruners = torch.nn.ModuleDict()
        for trigger_namespace in self._trigger_namespaces:
            # The trigger pruner.
            trigger_candidate_feedforward = make_feedforward(
                input_dim=trigger_emb_dim)
            self._trigger_pruners[trigger_namespace] = make_pruner(
                trigger_candidate_feedforward)
            # The trigger scorer.
            trigger_scorer_feedforward = make_feedforward(
                input_dim=trigger_emb_dim)
            self._trigger_scorers[trigger_namespace] = torch.nn.Sequential(
                TimeDistributed(trigger_scorer_feedforward),
                TimeDistributed(
                    torch.nn.Linear(
                        trigger_scorer_feedforward.get_output_dim(),
                        self._n_trigger_labels[trigger_namespace] - 1)))

        # Create argument scorers and pruners.
        self._mention_pruners = torch.nn.ModuleDict()
        self._argument_feedforwards = torch.nn.ModuleDict()
        self._argument_scorers = torch.nn.ModuleDict()
        for argument_namespace in self._argument_namespaces:
            # The argument pruner.
            mention_feedforward = make_feedforward(input_dim=span_emb_dim)
            self._mention_pruners[argument_namespace] = make_pruner(
                mention_feedforward)
            # The argument scorer. The `+ 2` is there because I include indicator features for
            # whether the trigger is before or inside the arg span.

            # Set up the argument feedforward.
            argument_feedforward_dim = trigger_emb_dim + span_emb_dim + feature_size + 2 + context_window_dim
            # feature_size = bucketed distance embedding; + 2 = the two position indicator features.
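            # Worked example (hypothetical sizes): trigger_emb_dim = 768, span_emb_dim = 1536,
            # feature_size = 20, context_window = 0 gives 768 + 1536 + 20 + 2 + 0 = 2326.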
            argument_feedforward = make_feedforward(
                input_dim=argument_feedforward_dim)
            self._argument_feedforwards[
                argument_namespace] = argument_feedforward
            self._argument_scorers[argument_namespace] = torch.nn.Linear(
                argument_feedforward.get_output_dim(),
                self._n_argument_labels[argument_namespace])

        # Weight on trigger labeling and argument labeling.
        self._loss_weights = loss_weights

        # Distance embeddings.
        self._num_distance_buckets = 10  # Just use 10, which is the default.
        self._distance_embedding = Embedding(
            embedding_dim=feature_size,
            num_embeddings=self._num_distance_buckets)

        self._trigger_spans_per_word = trigger_spans_per_word
        self._argument_spans_per_word = argument_spans_per_word

        # Metrics
        # Make a metric for each dataset (not each namespace).
        namespaces = self._trigger_namespaces + self._argument_namespaces
        datasets = {x.split("__")[0] for x in namespaces}
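        # E.g., a hypothetical namespace "ace__trigger_labels" yields the dataset key "ace".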
        self._metrics = {dataset: EventMetrics() for dataset in datasets}

        self._active_namespaces = {"trigger": None, "argument": None}
        self._active_dataset = None

        # Trigger and argument loss.
        self._trigger_loss = torch.nn.CrossEntropyLoss(reduction="sum")
        self._argument_loss = torch.nn.CrossEntropyLoss(reduction="sum",
                                                        ignore_index=-1)
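
The constructor above takes a `make_feedforward` factory and calls a module-level `make_pruner`, neither of which is shown on this page. Below is a minimal sketch of what these helpers could look like, assuming AllenNLP 0.9-style `FeedForward`, `Pruner`, and `TimeDistributed` modules; the hidden size, depth, and dropout are illustrative guesses, not the project's actual values.

import torch
from allennlp.modules import FeedForward, Pruner, TimeDistributed
from allennlp.nn import Activation

def make_feedforward(input_dim: int) -> FeedForward:
    # Two-layer ReLU feedforward; width and dropout are made-up defaults.
    return FeedForward(input_dim=input_dim,
                       num_layers=2,
                       hidden_dims=150,
                       activations=Activation.by_name("relu")(),
                       dropout=0.4)

def make_pruner(scorer_feedforward: FeedForward) -> Pruner:
    # Score every candidate span with the feedforward followed by a linear
    # projection to a single logit; the Pruner then keeps the top-scoring spans.
    scorer = torch.nn.Sequential(
        TimeDistributed(scorer_feedforward),
        TimeDistributed(torch.nn.Linear(scorer_feedforward.get_output_dim(), 1)))
    return Pruner(scorer)
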
Example #2
File: events.py  Project: zxlzr/dygiepp
    def __init__(
            self,
            vocab: Vocabulary,
            trigger_feedforward: FeedForward,
            trigger_candidate_feedforward: FeedForward,
            mention_feedforward: FeedForward,  # Used if entity beam is off.
            argument_feedforward: FeedForward,
            context_attention: BilinearMatrixAttention,
            trigger_attention: Seq2SeqEncoder,
            span_prop: SpanProp,
            cls_projection: FeedForward,
            feature_size: int,
            trigger_spans_per_word: float,
            argument_spans_per_word: float,
            loss_weights,
            trigger_attention_context: bool,
            event_args_use_trigger_labels: bool,
            event_args_use_ner_labels: bool,
            event_args_label_emb: int,
            shared_attention_context: bool,
            label_embedding_method: str,
            event_args_label_predictor: str,
            event_args_gold_candidates: bool = False,  # If True, use gold argument candidates.
            context_window: int = 0,
            softmax_correction: bool = False,
            initializer: InitializerApplicator = InitializerApplicator(),
            positive_label_weight: float = 1.0,
            entity_beam: bool = False,
            regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EventExtractor, self).__init__(vocab, regularizer)

        self._n_ner_labels = vocab.get_vocab_size("ner_labels")
        self._n_trigger_labels = vocab.get_vocab_size("trigger_labels")
        self._n_argument_labels = vocab.get_vocab_size("argument_labels")

        # Embeddings for trigger labels and ner labels, to be used by argument scorer.
        # These will be either one-hot encodings or learned embeddings, depending on "kind".
        self._ner_label_emb = make_embedder(kind=label_embedding_method,
                                            num_embeddings=self._n_ner_labels,
                                            embedding_dim=event_args_label_emb)
        self._trigger_label_emb = make_embedder(
            kind=label_embedding_method,
            num_embeddings=self._n_trigger_labels,
            embedding_dim=event_args_label_emb)
        self._label_embedding_method = label_embedding_method

        # Weight on trigger labeling and argument labeling.
        self._loss_weights = loss_weights.as_dict()

        # Trigger candidate scorer.
        null_label = vocab.get_token_index("", "trigger_labels")
        assert null_label == 0  # If not, the dummy class won't correspond to the null label.

        self._trigger_scorer = torch.nn.Sequential(
            TimeDistributed(trigger_feedforward),
            TimeDistributed(
                torch.nn.Linear(trigger_feedforward.get_output_dim(),
                                self._n_trigger_labels - 1)))

        self._trigger_attention_context = trigger_attention_context
        if self._trigger_attention_context:
            self._trigger_attention = trigger_attention

        # Make pruners. If `entity_beam` is true, use NER and trigger scorers to construct the beam
        # and only keep candidates that the model predicts are actual entities or triggers.
        self._mention_pruner = make_pruner(
            mention_feedforward,
            entity_beam=entity_beam,
            gold_beam=event_args_gold_candidates)
        self._trigger_pruner = make_pruner(trigger_candidate_feedforward,
                                           entity_beam=entity_beam,
                                           gold_beam=False)

        # Argument scorer.
        self._event_args_use_trigger_labels = event_args_use_trigger_labels  # If True, use trigger labels.
        self._event_args_use_ner_labels = event_args_use_ner_labels  # If True, use ner labels to predict args.
        assert event_args_label_predictor in [
            "hard", "softmax", "gold"
        ]  # Method for predicting labels at test time.
        self._event_args_label_predictor = event_args_label_predictor
        self._event_args_gold_candidates = event_args_gold_candidates
        self._context_window = context_window  # If greater than 0, concatenate context as features.
        self._argument_feedforward = argument_feedforward
        self._argument_scorer = torch.nn.Linear(
            argument_feedforward.get_output_dim(), self._n_argument_labels)

        # Distance embeddings.
        self._num_distance_buckets = 10  # Just use 10, which is the default.
        self._distance_embedding = Embedding(self._num_distance_buckets,
                                             feature_size)

        # Class token projection.
        self._cls_projection = cls_projection
        self._cls_n_triggers = torch.nn.Linear(
            self._cls_projection.get_output_dim(), 5)
        self._cls_event_types = torch.nn.Linear(
            self._cls_projection.get_output_dim(), self._n_trigger_labels - 1)

        self._trigger_spans_per_word = trigger_spans_per_word
        self._argument_spans_per_word = argument_spans_per_word

        # Context attention for event argument scorer. If `shared_attention_context` is True,
        # construct a context vector from a bilinear attention over the trigger / argument pair
        # embeddings and the text.
        self._shared_attention_context = shared_attention_context
        if self._shared_attention_context:
            self._shared_attention_context_module = context_attention

        # Span propagation object.
        # TODO(dwadden) initialize with `from_params` instead if this ends up working.
        self._span_prop = span_prop
        self._span_prop._trig_arg_embedder = self._compute_trig_arg_embeddings
        self._span_prop._argument_scorer = self._compute_argument_scores

        # Softmax correction parameters.
        self._softmax_correction = softmax_correction
        self._softmax_log_temp = torch.nn.Parameter(
            torch.zeros([1, 1, 1, self._n_argument_labels]))
        self._softmax_log_multiplier = torch.nn.Parameter(
            torch.zeros([1, 1, 1, self._n_argument_labels]))

        # TODO(dwadden) Add metrics.
        self._metrics = EventMetrics()
        self._argument_stats = ArgumentStats()

        self._trigger_loss = torch.nn.CrossEntropyLoss(reduction="sum")
        # TODO(dwadden) add loss weights.
        self._argument_loss = torch.nn.CrossEntropyLoss(reduction="sum",
                                                        ignore_index=-1)
        initializer(self)
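
Example #2 also relies on a `make_embedder` helper which, per the comment near the top of the constructor, returns either one-hot encodings or learned embeddings depending on `kind`. A minimal sketch under that assumption (the kind names "one_hot" and "learned" are hypothetical):

import torch

def make_embedder(kind: str, num_embeddings: int, embedding_dim: int) -> torch.nn.Module:
    if kind == "one_hot":
        # Frozen identity matrix: label i maps to its one-hot vector, so the
        # output width is num_embeddings rather than embedding_dim.
        return torch.nn.Embedding.from_pretrained(torch.eye(num_embeddings), freeze=True)
    if kind == "learned":
        return torch.nn.Embedding(num_embeddings, embedding_dim)
    raise ValueError(f"Unknown label embedding method: {kind}")
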
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 make_feedforward: Callable,
                 token_emb_dim: int,   # Triggers are represented via token embeddings.
                 span_emb_dim: int,    # Arguments are represented via span embeddings.
                 feature_size: int,
                 trigger_spans_per_word: float,
                 argument_spans_per_word: float,
                 loss_weights: Dict[str, float],
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(EventExtractor, self).__init__(vocab, regularizer)

        self._trigger_namespaces = [entry for entry in vocab.get_namespaces()
                                    if "trigger_labels" in entry]
        self._argument_namespaces = [entry for entry in vocab.get_namespaces()
                                     if "argument_labels" in entry]

        self._n_trigger_labels = {name: vocab.get_vocab_size(name)
                                  for name in self._trigger_namespaces}
        self._n_argument_labels = {name: vocab.get_vocab_size(name)
                                   for name in self._argument_namespaces}

        # Make sure the null trigger label is always 0.
        for namespace in self._trigger_namespaces:
            null_label = vocab.get_token_index("", namespace)
            assert null_label == 0  # If not, the dummy class won't correspond to the null label.

        # Create trigger scorers and pruners.
        self._trigger_scorers = torch.nn.ModuleDict()
        self._trigger_pruners = torch.nn.ModuleDict()
        for trigger_namespace in self._trigger_namespaces:
            # The trigger pruner.
            trigger_candidate_feedforward = make_feedforward(input_dim=token_emb_dim)
            self._trigger_pruners[trigger_namespace] = make_pruner(trigger_candidate_feedforward)
            # The trigger scorer.
            trigger_feedforward = make_feedforward(input_dim=token_emb_dim)
            self._trigger_scorers[trigger_namespace] = torch.nn.Sequential(
                TimeDistributed(trigger_feedforward),
                TimeDistributed(torch.nn.Linear(trigger_feedforward.get_output_dim(),
                                                self._n_trigger_labels[trigger_namespace] - 1)))

        # Create argument scorers and pruners.
        self._mention_pruners = torch.nn.ModuleDict()
        self._argument_feedforwards = torch.nn.ModuleDict()
        self._argument_scorers = torch.nn.ModuleDict()
        for argument_namespace in self._argument_namespaces:
            # The argument pruner.
            mention_feedforward = make_feedforward(input_dim=span_emb_dim)
            self._mention_pruners[argument_namespace] = make_pruner(mention_feedforward)
            # The argument scorer. The `+ 2` is there because I include indicator features for
            # whether the trigger is before or inside the arg span.

            # TODO(dwadden) Here
            argument_feedforward_dim = token_emb_dim + span_emb_dim + feature_size + 2
            argument_feedforward = make_feedforward(input_dim=argument_feedforward_dim)
            self._argument_feedforwards[argument_namespace] = argument_feedforward
            self._argument_scorers[argument_namespace] = torch.nn.Linear(
                argument_feedforward.get_output_dim(), self._n_argument_labels[argument_namespace])

        # Weight on trigger labeling and argument labeling.
        self._loss_weights = loss_weights

        # Distance embeddings.
        self._num_distance_buckets = 10  # Just use 10, which is the default.
        self._distance_embedding = Embedding(embedding_dim=feature_size,
                                             num_embeddings=self._num_distance_buckets)

        self._trigger_spans_per_word = trigger_spans_per_word
        self._argument_spans_per_word = argument_spans_per_word

        # Metrics
        # TODO(dwadden) Need different metrics for different namespaces.
        self._metrics = EventMetrics()

        self._active_namespaces = {"trigger": None, "argument": None}

        # Trigger and argument loss.
        self._trigger_loss = torch.nn.CrossEntropyLoss(reduction="sum")
        self._argument_loss = torch.nn.CrossEntropyLoss(reduction="sum", ignore_index=-1)
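
A hypothetical usage sketch for Example #3, showing how the vocabulary namespaces drive the per-namespace scorer and pruner dictionaries; all sizes, label names, and loss-weight keys below are illustrative assumptions, not values from the project:

from allennlp.data import Vocabulary

vocab = Vocabulary()
# Namespaces ending in "labels" are unpadded in AllenNLP, so the null label ""
# added first lands at index 0, satisfying the assertion in the constructor.
vocab.add_token_to_namespace("", namespace="ace__trigger_labels")
vocab.add_token_to_namespace("Attack", namespace="ace__trigger_labels")
vocab.add_token_to_namespace("Target", namespace="ace__argument_labels")

extractor = EventExtractor(vocab=vocab,
                           make_feedforward=make_feedforward,  # e.g. the sketch after Example #1
                           token_emb_dim=768,
                           span_emb_dim=1536,
                           feature_size=20,
                           trigger_spans_per_word=0.4,
                           argument_spans_per_word=0.8,
                           loss_weights={"trigger": 1.0, "arguments": 1.0})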