    def test_widths_are_embedded_correctly(self):
        input_dim = 7
        max_span_width = 5
        span_width_embedding_dim = 3
        output_dim = input_dim + span_width_embedding_dim
        extractor = SelfAttentiveSpanExtractor(
            input_dim=input_dim,
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=span_width_embedding_dim,
        )
        assert extractor.get_output_dim() == output_dim
        assert extractor.get_input_dim() == input_dim

        sequence_tensor = torch.randn([2, max_span_width, input_dim])
        indices = torch.LongTensor([[[1, 3], [0, 4], [0, 0]],
                                    [[0, 2], [1, 4],
                                     [2, 2]]])  # smaller span tests masking.
        span_representations = extractor(sequence_tensor, indices)
        assert list(span_representations.size()) == [2, 3, output_dim]

        width_embeddings = extractor._span_width_embedding.weight.data.numpy()
        widths_minus_one = indices[..., 1] - indices[..., 0]
        for element in range(indices.size(0)):
            for span in range(indices.size(1)):
                width = widths_minus_one[element, span].item()
                width_embedding = span_representations[element, span,
                                                       input_dim:]
                numpy.testing.assert_array_almost_equal(
                    width_embedding.data.numpy(), width_embeddings[width])
Example #2
    def test_attention_is_normalised_correctly(self):
        input_dim = 7
        sequence_tensor = torch.randn([2, 5, input_dim])
        extractor = SelfAttentiveSpanExtractor(input_dim=input_dim)
        assert extractor.get_output_dim() == input_dim
        assert extractor.get_input_dim() == input_dim

        # In order to test the attention, we'll make the weight which computes the logits
        # zero, so the attention distribution is uniform over the sentence. This lets
        # us check that the computed spans are just the averages of their representations.
        extractor._global_attention._module.weight.data.fill_(0.0)
        extractor._global_attention._module.bias.data.fill_(0.0)

        indices = torch.LongTensor(
            [[[1, 3], [2, 4]], [[0, 2], [3, 4]]]
        )  # smaller span tests masking.
        span_representations = extractor(sequence_tensor, indices)
        assert list(span_representations.size()) == [2, 2, input_dim]

        # First element in the batch.
        batch_element = 0
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 1:4, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 2:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())
        # Now the second element in the batch.
        batch_element = 1
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 0:3, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 3:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())

        # Now test the case in which we have some masked spans in our indices.
        indices_mask = torch.BoolTensor([[True, True], [True, False]])
        span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)

        # First element in the batch.
        batch_element = 0
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 1:4, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 2:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())
        # Now the second element in the batch.
        batch_element = 1
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 0:3, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span was masked, so should be completely zero.
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), numpy.zeros([input_dim]))
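For reference, the computation both tests rely on can be written out by hand: a single linear layer scores each token, the scores are softmaxed within the span, and the span representation is the resulting weighted sum. An illustrative re-implementation for one unbatched span (not the module's actual code path, which also handles batching, masking, and optional width embeddings):

import torch

def attentive_span(sequence, start, end, weight, bias):
    tokens = sequence[start:end + 1]             # span end is inclusive
    logits = tokens @ weight + bias              # per-token attention logits
    attn = torch.softmax(logits, dim=0)          # normalised within the span
    return (attn.unsqueeze(-1) * tokens).sum(0)  # attention-weighted sum

seq = torch.randn(5, 7)
w, b = torch.zeros(7), torch.tensor(0.0)         # zero weights -> uniform attention
assert torch.allclose(attentive_span(seq, 1, 3, w, b), seq[1:4].mean(0), atol=1e-6)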
Example #3
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 context_layer: Seq2SeqEncoder,
                 mention_feedforward: FeedForward,
                 antecedent_feedforward: FeedForward,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 max_antecedents: int,
                 lexical_dropout: float = 0.2,
                 context_layer_back: Seq2SeqEncoder = None,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(CoreferenceResolver, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._context_layer_back = context_layer_back
        self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(mention_feedforward),
            TimeDistributed(
                torch.nn.Linear(mention_feedforward.get_output_dim(), 1)))
        self._mention_pruner = SpanPruner(feedforward_scorer)
        self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1))
        # TODO check the output dim when two context layers are passed through
        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim())

        # 10 possible distance buckets.
        self._num_distance_buckets = 10
        self._distance_embedding = Embedding(self._num_distance_buckets,
                                             feature_size)
        self._speaker_embedding = Embedding(2, feature_size)
        self.genres = {
            g: i
            for i, g in enumerate(['bc', 'bn', 'mz', 'nw', 'pt', 'tc', 'wb'])
        }
        self._genre_embedding = Embedding(len(self.genres), feature_size)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._max_antecedents = max_antecedents

        self._mention_recall = MentionRecall()
        self._conll_coref_scores = ConllCorefScores()
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x
        self._feature_dropout = torch.nn.Dropout(0.2)
        initializer(self)
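The forward pass of this model (not shown here) follows the usual pattern of concatenating the endpoint-based and attention-based span representations, so the span embedding size is the sum of the two extractors' output dims. A quick sketch with hypothetical sizes; the real values come from the configuration:

from allennlp.modules.span_extractors import (
    EndpointSpanExtractor, SelfAttentiveSpanExtractor)

context_dim, embedder_dim, feature_size, max_span_width = 400, 300, 20, 10
endpoint = EndpointSpanExtractor(context_dim, combination="x,y",
                                 num_width_embeddings=max_span_width,
                                 span_width_embedding_dim=feature_size)
attentive = SelfAttentiveSpanExtractor(input_dim=embedder_dim)
span_dim = endpoint.get_output_dim() + attentive.get_output_dim()
print(span_dim)  # 2 * 400 + 20 + 300 = 1120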
Example #4
    def __init__(
        self,
        use_citation_graph_embeddings: bool,
        citation_embedding_file: str,
        doc_to_idx_mapping_file: str,
        finetune_embedding: bool,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        context_layer: Seq2SeqEncoder,
        modules: Params,
        loss_weights: Dict[str, int],
        lexical_dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
        display_metrics: List[str] = None,
    ) -> None:
        super(SalientOnlyModel, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)

        if use_citation_graph_embeddings:
            if citation_embedding_file == "" or doc_to_idx_mapping_file == "":
                raise ValueError(
                    "Must supply citation embedding files to use graph embedding features"
                )
            self._document_embedding = initialize_graph_embeddings(
                citation_embedding_file, finetune_embedding=finetune_embedding)
            self._doc_to_idx_mapping = json.load(open(doc_to_idx_mapping_file))
        else:
            self._document_embedding = None
            self._doc_to_idx_mapping = None

        modules = Params(modules)

        self._saliency_classifier = SpanClassifier.from_params(
            vocab=vocab,
            document_embedding=self._document_embedding,
            doc_to_idx_mapping=self._doc_to_idx_mapping,
            params=modules.pop("saliency_classifier"))
        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(), combination="x,y")
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=context_layer.get_output_dim())

        for k in loss_weights:
            loss_weights[k] = float(loss_weights[k])

        self._loss_weights = loss_weights
        self._permanent_loss_weights = copy.deepcopy(self._loss_weights)

        self._display_metrics = display_metrics
        self._multi_task_loss_metrics = {k: Average() for k in ["saliency"]}

        self.training_mode = True
        self.prediction_mode = False

        initializer(self)
Example #5
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        context_layer: Seq2SeqEncoder,
        mention_feedforward: FeedForward,
        antecedent_feedforward: FeedForward,
        feature_size: int,
        max_span_width: int,
        spans_per_word: float,
        max_antecedents: int,
        lexical_dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer: Optional[RegularizerApplicator] = None,
    ) -> None:
        super().__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(mention_feedforward),
            TimeDistributed(torch.nn.Linear(mention_feedforward.get_output_dim(), 1)),
        )
        self._mention_pruner = Pruner(feedforward_scorer)
        self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1)
        )

        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False,
        )
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim()
        )

        # 10 possible distance buckets.
        self._num_distance_buckets = 10
        self._distance_embedding = Embedding(self._num_distance_buckets, feature_size)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._max_antecedents = max_antecedents

        self._mention_recall = MentionRecall()
        self._conll_coref_scores = ConllCorefScores()
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x
        initializer(self)
    def __init__(self,
                 vocab: Vocabulary,
                 span_typer: SpanTyper,
                 embed_size: int,
                 label_namespace: str = 'span_labels',
                 event_namespace: str = 'event_labels'):
        super(ArgumentSpanClassifier, self).__init__()

        self.vocab: Vocabulary = vocab
        self.label_namespace: str = label_namespace
        self.event_namespace: str = event_namespace

        self.embed_size = embed_size
        self.event_embedding_size = 50

        self.event_embeddings: nn.Embedding = nn.Embedding(
            num_embeddings=len(
                vocab.get_token_to_index_vocabulary(
                    namespace=event_namespace)),
            embedding_dim=self.event_embedding_size)

        self.lexical_dropout = nn.Dropout(p=0.2)
        self.span_extractor: SpanExtractor = EndpointSpanExtractor(
            input_dim=self.embed_size, combination='x,y')
        self.attentive_span_extractor: SpanExtractor = SelfAttentiveSpanExtractor(
            embed_size)

        self.arg_affine = TimeDistributed(
            FeedForward(input_dim=self.span_extractor.get_output_dim() +
                        self.attentive_span_extractor.get_output_dim(),
                        hidden_dims=self.embed_size,
                        num_layers=2,
                        activations=nn.GELU(),
                        dropout=0.2))
        self.trigger_affine = FeedForward(
            input_dim=self.span_extractor.get_output_dim() +
            self.attentive_span_extractor.get_output_dim(),
            hidden_dims=self.embed_size - self.event_embedding_size,
            num_layers=2,
            activations=nn.GELU(),
            dropout=0.2)

        self.trigger_event_infusion = TimeDistributed(
            FeedForward(input_dim=2 * self.embed_size,
                        hidden_dims=self.embed_size,
                        num_layers=2,
                        activations=nn.GELU(),
                        dropout=0.2))

        self.span_typer: SpanTyper = span_typer

        self.apply(self._init_weights)
Example #7
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        context_layer: Seq2SeqEncoder,
        modules: Params,
        loss_weights: Dict[str, int],
        lexical_dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        regularizer=None,
        #regularizer: Optional[GbiRegularizerApplicator] = None,
        display_metrics: List[str] = None,
    ) -> None:
        super(ScirexModel, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)

        modules = Params(modules)

        self._ner = NERTagger.from_params(vocab=vocab,
                                          params=modules.pop("ner"))
        self._saliency_classifier = SpanClassifier.from_params(
            vocab=vocab, params=modules.pop("saliency_classifier"))
        self._cluster_n_ary_relation = NAryRelationExtractor.from_params(
            vocab=vocab, params=modules.pop("n_ary_relation"))

        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(), combination="x,y")
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=context_layer.get_output_dim())

        for k in loss_weights:
            loss_weights[k] = float(loss_weights[k])
        self._loss_weights = loss_weights
        self._permanent_loss_weights = copy.deepcopy(self._loss_weights)

        self._display_metrics = display_metrics
        self._multi_task_loss_metrics = {
            k: Average()
            for k in ["ner", "saliency", "n_ary_relation"]
        }

        self.training_mode = True
        self.prediction_mode = False

        initializer(self)
Example #8
    def __init__(self,
                 node_types_vocabulary=None,
                 node_attrs_vocabulary=None,
                 p2p_edges_vocabulary=None,
                 p2r_edges_vocabulary=None,
                 bilstm_hidden_embedding_dim=200,
                 lexical_dropout=0.5,
                 lstm_dropout=0.4,
                 max_span_width=15,
                 feature_size=20,
                 embed_mode='bert-base-cased',
                 device=torch.device("cuda")):
        super().__init__()
        self.node_types_vocabulary = node_types_vocabulary
        self.node_attrs_vocabulary = node_attrs_vocabulary
        self.p2p_edges_vocabulary = p2p_edges_vocabulary
        self.p2r_edges_vocabulary = p2r_edges_vocabulary
        self.bilstm_hidden_embedding_dim = bilstm_hidden_embedding_dim
        self.lexical_dropout = lexical_dropout
        self.lstm_dropout = lstm_dropout
        self.embed_mode = embed_mode
        self.device = device
        self.max_span_width = max_span_width
        self.feature_size = feature_size

        if self.embed_mode == 'bert-base-cased':
            self.bert = AutoModel.from_pretrained("bert-base-cased")
            self.bert_hidden_embedding_dim = 768
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x
        self.bilstm = LSTM(input_size=self.bert_hidden_embedding_dim,
                           hidden_size=self.bilstm_hidden_embedding_dim,
                           dropout=self.lstm_dropout,
                           bidirectional=True,
                           num_layers=6)
        self._endpoint_span_extractor = EndpointSpanExtractor(
            self.bilstm_hidden_embedding_dim,
            combination="x,y",
            num_width_embeddings=self.max_span_width,
            span_width_embedding_dim=self.feature_size,
            bucket_widths=False,
        )
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.bert_hidden_embedding_dim)
Example #9
    def __init__(self, bert_hidden_size: int, cnn_context: int, hidden_size: int):
        super().__init__()
        self.bert_hidden_size = bert_hidden_size
        self.cnn_context = cnn_context
        self.proj_dim = 64
        self.k = 1 + 2 * self.cnn_context
        self.hidden_size = hidden_size

        self.span_extractor = SelfAttentiveSpanExtractor(self.proj_dim)  # span extractor comes directly after BERT
        # all the main parameters are coming from the conv layer
        self.context_conv = nn.Conv1d(self.bert_hidden_size, self.proj_dim, kernel_size=self.k, stride=1,
                                      padding=self.cnn_context, dilation=1, groups=1, bias=True)

        self.fc = nn.Sequential(
            nn.BatchNorm1d(self.proj_dim * 3),
            #             nn.Dropout(0.7),
            nn.Linear(self.proj_dim * 3, self.hidden_size),
            nn.ReLU(),
            nn.BatchNorm1d(self.hidden_size),
            nn.Dropout(0.6),
        )

        self.new_last_layer = nn.Linear(self.hidden_size + 9 + 1 + 3 + 2 + 2, 3)
        # hidden_size dims come from the projected span features above, 2 are from the url,
        # 9 are the other features, 1 is gender, 3 are syntactic distances, and 2 are the
        # distances to the root

        # after fine-tuning BERT this is not required, throw away
        for i, module in enumerate(self.fc):
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                print("Initing batchnorm")
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                    print("Initing linear with weight normalization")
                    assert module.weight_g is not None
                else:
                    nn.init.kaiming_normal_(module.weight)
                    print("Initing linear")
                nn.init.constant_(module.bias, 0)
Example #10
    def __init__(self,
                 input_dim,
                 # FFNN projection parameters
                 project=True,
                 hidden_dim=100,
                 activation='tanh',
                 dropout=0.0,
                 # General config
                 span_end_is_exclusive=True,
                 ):
        super(SpanEmbedder, self).__init__()

        # Span end is exclusive (like Python or C)
        self.span_end_is_exclusive = bool(span_end_is_exclusive)

        # Self-attentive span extractor
        self.extractor = SelfAttentiveSpanExtractor(input_dim=input_dim)

        # Nonlinear projection via feedforward neural network
        self.project = project
        if self.project:
            self.ffnn = FeedForward(input_dim=input_dim,
                                    num_layers=1,
                                    hidden_dims=hidden_dim,
                                    activations=get_activation(activation),
                                    dropout=dropout)
            self.output_dim = hidden_dim
        else:
            self.output_dim = input_dim
Example #11
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 classifier_feedforward: FeedForward,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(RNNClassifier, self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.num_classes = self.vocab.get_vocab_size("labels")
        self.encoder = encoder
        self.endpoint_span_extractor = EndpointSpanExtractor(
            encoder.get_output_dim(), combination="x,y")
        self.attentive_span_extractor = SelfAttentiveSpanExtractor(
            encoder.get_output_dim())

        attention_input_dim = encoder.get_output_dim() * 2
        self.holder_attention = nn.Linear(attention_input_dim, 1)
        self.target_attention = nn.Linear(attention_input_dim, 1)

        self.classifier_feedforward = classifier_feedforward

        if text_field_embedder.get_output_dim() != encoder.get_input_dim():
            raise ConfigurationError(
                "The output dimension of the text_field_embedder must match the "
                "input dimension of the title_encoder. Found {} and {}, "
                "respectively.".format(text_field_embedder.get_output_dim(),
                                       encoder.get_input_dim()))
        self.metrics = {
            "f1_neg": F1Measure(1),
            "f1_none": F1Measure(0),
            "f1_pos": F1Measure(2),
        }
        self.loss = torch.nn.CrossEntropyLoss()

        initializer(self)
Example #12
    def __init__(self,
                 vocab: Vocabulary,
                 span_graph_encoder: SpanGraphEncoder,
                 span_typer: SpanTyper,
                 embed_size: int,
                 label_namespace: str = 'span_labels',
                 event_namespace: str = 'event_labels',
                 use_event_embedding: bool = True):
        super(SelectorArgLinking, self).__init__()

        self.vocab: Vocabulary = vocab
        self.label_namespace: str = label_namespace
        self.event_namespace: str = event_namespace

        self.use_event_embedding = use_event_embedding
        self.embed_size = embed_size
        self.event_embedding_size = 50

        # self.span_finder: SpanFinder = span_finder
        # self.span_selector: SpanSelector = span_selector
        if use_event_embedding:
            self.event_embeddings: nn.Embedding = nn.Embedding(
                num_embeddings=len(vocab.get_token_to_index_vocabulary(namespace=event_namespace)),
                embedding_dim=self.event_embedding_size
            )

        self.lexical_dropout = nn.Dropout(p=0.2)
        # self.contextualized_encoder: Seq2SeqEncoder = LstmSeq2SeqEncoder(
        #     bidirectional=True,
        #     input_size=embed_size,
        #     hidden_size=embed_size,
        #     num_layers=2,
        #     dropout=0.4
        # )
        self.span_graph_encoder: SpanGraphEncoder = span_graph_encoder
        self.span_extractor: SpanExtractor = EndpointSpanExtractor(
            # input_dim=self.contextualized_encoder.get_output_dim(),
            input_dim=self.embed_size,
            combination='x,y'
        )
        self.attentive_span_extractor: SpanExtractor = SelfAttentiveSpanExtractor(embed_size)

        self.arg_affine = TimeDistributed(FeedForward(
            input_dim=self.span_extractor.get_output_dim() + self.attentive_span_extractor.get_output_dim(),
            hidden_dims=self.span_graph_encoder.get_input_dim(),
            num_layers=2,
            activations=nn.GELU(),
            dropout=0.2
        ))
        self.trigger_affine = FeedForward(
            input_dim=self.span_extractor.get_output_dim() + self.attentive_span_extractor.get_output_dim(),
            hidden_dims=self.span_graph_encoder.get_input_dim() - (
                self.event_embedding_size if use_event_embedding else 0),
            num_layers=2,
            activations=nn.GELU(),
            dropout=0.2
        )
        # self.arg_affine: nn.Linear = nn.Linear(
        #     self.span_extractor.get_output_dim() + self.attentive_span_extractor.get_output_dim(),
        #     self.span_graph_encoder.get_input_dim()
        # )
        # self.trigger_affine: nn.Linear = nn.Linear(
        #     self.span_extractor.get_output_dim() + self.attentive_span_extractor.get_output_dim(),
        #     self.span_graph_encoder.get_input_dim()
        # )

        # self.trigger_event_infuse: nn.Sequential = nn.Sequential(
        #     nn.Dropout(p=0.1),
        #     nn.Linear(4 * self.span_graph_encoder.get_input_dim(), 2 * self.span_graph_encoder.get_input_dim()),
        #     nn.Dropout(p=0.1),
        #     nn.GELU(),
        #     nn.Linear(2 * self.span_graph_encoder.get_input_dim(), self.span_graph_encoder.get_input_dim()),
        #     nn.Dropout(p=0.1),
        #     nn.GELU()
        # )

        self.span_typer: SpanTyper = span_typer

        self.apply(self._init_weights)
Example #13
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        mention_feedforward: FeedForward,
        context_layer: Seq2SeqEncoder = None,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        max_span_width: int = 30,
        feature_size: int = 10,
        spans_per_word: float = 100,
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("span_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1))

        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.bert_model.config.hidden_size)
        self.span_representation_dim = self._attentive_span_extractor.get_output_dim()
        self._context_layer = context_layer
        if context_layer is not None:
            self._endpoint_span_extractor = EndpointSpanExtractor(
                context_layer.get_output_dim(),
                combination="x,y",
                num_width_embeddings=max_span_width,
                span_width_embedding_dim=feature_size,
                bucket_widths=False,
            )
            self.span_representation_dim = self._endpoint_span_extractor.get_output_dim()

        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Linear(self.span_representation_dim +
                            self.bert_model.config.hidden_size,
                            self.span_representation_dim,
                            bias=False), torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(self.span_representation_dim,
                                            self.num_classes - 1,
                                            bias=False)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self._bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        initializer(self)
    def _make_span_extractor(self):
        if self.span_pooling == "attn":
            return SelfAttentiveSpanExtractor(self.proj_dim)
        else:
            return EndpointSpanExtractor(self.proj_dim,
                                         combination=self.span_pooling)
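The two branches return extractors with different output sizes, which downstream layers must account for. A quick check, with proj_dim chosen arbitrarily for illustration:

from allennlp.modules.span_extractors import (
    EndpointSpanExtractor, SelfAttentiveSpanExtractor)

proj_dim = 128
assert SelfAttentiveSpanExtractor(proj_dim).get_output_dim() == proj_dim
# "x,y" concatenates the two endpoint vectors; other combinations change the size.
assert EndpointSpanExtractor(proj_dim, combination="x,y").get_output_dim() == 2 * proj_dim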
Example #15
    def __init__(
            self,
            vocab: Vocabulary,
            text_field_embedder: TextFieldEmbedder,
            context_layer: Seq2SeqEncoder,
            modules,  # TODO(dwadden) Add type.
            feature_size: int,
            max_span_width: int,
            loss_weights: Dict[str, int],
            lexical_dropout: float = 0.2,
            lstm_dropout: float = 0.4,
            use_attentive_span_extractor: bool = False,
            co_train: bool = False,
            initializer: InitializerApplicator = InitializerApplicator(),
            regularizer: Optional[RegularizerApplicator] = None,
            display_metrics: List[str] = None) -> None:
        super(DyGIE, self).__init__(vocab, regularizer)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer

        self._loss_weights = loss_weights
        self._permanent_loss_weights = copy.deepcopy(self._loss_weights)

        # Need to add this line so things don't break. TODO(dwadden) sort out what's happening.
        modules = Params(modules)
        self._coref = CorefResolver.from_params(vocab=vocab,
                                                feature_size=feature_size,
                                                params=modules.pop("coref"))
        self._ner = NERTagger.from_params(vocab=vocab,
                                          feature_size=feature_size,
                                          params=modules.pop("ner"))
        self._relation = RelationExtractor.from_params(
            vocab=vocab,
            feature_size=feature_size,
            params=modules.pop("relation"))
        self._events = EventExtractor.from_params(vocab=vocab,
                                                  feature_size=feature_size,
                                                  params=modules.pop("events"))

        # Make endpoint span extractor.

        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        if use_attentive_span_extractor:
            self._attentive_span_extractor = SelfAttentiveSpanExtractor(
                input_dim=text_field_embedder.get_output_dim())
        else:
            self._attentive_span_extractor = None

        self._max_span_width = max_span_width

        self._display_metrics = display_metrics

        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        # Do co-training if we're training on ACE and ontonotes.
        self._co_train = co_train

        # Big gotcha: PyTorch doesn't add dropout to the LSTM's output layer. We need to do this
        # manually.
        if lstm_dropout > 0:
            self._lstm_dropout = torch.nn.Dropout(p=lstm_dropout)
        else:
            self._lstm_dropout = lambda x: x

        initializer(self)
Example #16
    def __init__(
        self,
        encoder,
        encoder_pretrained,
        encoder_frozen,
        decoder_hidden,
        embeddings,
        max_layer=12,
        src_pad_idx=0,
        encoder_hidden=None,
        variational=None,
        latent_size=None,
        scalar_mix=False,
        aggregator="mean",
        teacher_forcing_p=0.3,
        classification=None,
        attentional=False,
        definition_encoder=None,
        word_dropout_p=None,
        decoder_num_layers=None,
    ):
        super(DefinitionProbing, self).__init__()

        self.embeddings = embeddings
        self.variational = variational
        self.encoder_hidden = encoder_hidden
        self.decoder_hidden = decoder_hidden
        self.decoder_num_layers = decoder_num_layers
        self.encoder = encoder
        self.latent_size = latent_size
        self.src_pad_idx = src_pad_idx
        if encoder_pretrained:
            self.encoder_hidden = self.encoder.config.hidden_size
        if encoder_frozen:
            for param in self.encoder.parameters():
                param.requires_grad = False
        self.max_layer = max_layer
        self.aggregator = aggregator
        if self.aggregator == "span":
            self.span_extractor = SelfAttentiveSpanExtractor(self.encoder_hidden)
        self.context_feed_forward = nn.Linear(self.encoder_hidden, self.encoder_hidden)
        self.scalar_mix = None
        if scalar_mix:
            self.scalar_mix = ScalarMix(self.max_layer + 1)
        self.global_scorer = GNMTGlobalScorer(
            alpha=2, beta=None, length_penalty="avg", coverage_penalty=None
        )

        self.decoder = LSTM_Decoder(
            embeddings.tgt,
            hidden=self.decoder_hidden,
            encoder_hidden=self.encoder_hidden,
            num_layers=self.decoder_num_layers,
            word_dropout=word_dropout_p,
            teacher_forcing_p=teacher_forcing_p,
            attention="general" if attentional else None,
            dropout=DotMap({"input": 0.5, "output": 0.5}),
            decoder="VDM" if self.variational else "LSTM",
            variational=self.variational,
            latent_size=self.latent_size,
        )

        self.target_kl = 1.0
        if self.variational:
            self.definition_encoder = definition_encoder
            self.definition_feed_forward = nn.Linear(
                self.encoder_hidden, self.encoder_hidden
            )
            self.mean_layer = nn.Linear(self.latent_size, self.latent_size)
            self.logvar_layer = nn.Linear(self.latent_size, self.latent_size)
            self.w_z_post = nn.Sequential(
                nn.Linear(self.encoder_hidden * 2, self.latent_size), nn.Tanh()
            )
            self.mean_prime_layer = nn.Linear(self.latent_size, self.latent_size)
            self.logvar_prime_layer = nn.Linear(self.latent_size, self.latent_size)
            self.w_z_prior = nn.Sequential(
                nn.Linear(self.encoder_hidden, self.latent_size), nn.Tanh()
            )
            self.z_project = nn.Sequential(
                nn.Linear(self.latent_size, self.decoder_hidden), nn.Tanh()
            )
Example #17
    def _make_span_extractor(self):
        return SelfAttentiveSpanExtractor(self.proj_dim)
Example #18
    def __init__(self, num_labels, vocab_size, word_embeddings_size,
                 hidden_dim, word_embeddings, num_polarities, batch_size,
                 dropout_rate, max_co_occurs, feature_embeddings_size=25,
                 ablations=None):
        super(Model1, self).__init__()
        self.dropout = nn.Dropout(dropout_rate)

        self.hidden_dim = hidden_dim
        self.batch_size = batch_size

        # Specify embedding layers
        self.word_embeds = nn.Embedding(vocab_size, word_embeddings_size)
        self.word_embeds.weight.data.copy_(torch.FloatTensor(word_embeddings))
        # self.word_embeds.weight.requires_grad = False  # don't update the embeddings
        self.polarity_embeds = nn.Embedding(num_polarities + 1, word_embeddings_size)  # add 1 for <pad>
        self.co_occur_embeds = nn.Embedding(max_co_occurs, feature_embeddings_size)
        self.holder_target_embeds = nn.Embedding(5, word_embeddings_size)  # add 2 for <pad> and <unk>
        self.num_holder_mention_embeds = nn.Embedding(num_mentions_cats, feature_embeddings_size)
        self.num_target_mention_embeds = nn.Embedding(num_mentions_cats, feature_embeddings_size)
        self.min_mention_embeds = nn.Embedding(num_mentions_cats, feature_embeddings_size)
        self.holder_rank_embeds = nn.Embedding(5, feature_embeddings_size)
        self.target_rank_embeds = nn.Embedding(5, feature_embeddings_size)
        self.sent_classify_embeds = nn.Embedding(num_labels, feature_embeddings_size)

        # The LSTM takes [word embeddings, feature embeddings, holder/target embeddings] as inputs, and
        # outputs hidden states with dimensionality hidden_dim.
        self.lstm = nn.LSTM(2 * word_embeddings_size, hidden_dim, num_layers=2,
                            batch_first=True, bidirectional=True, dropout=dropout_rate)

        # The linear layer that maps from hidden state space to target space
        self.hidden2label = nn.Linear(2 * hidden_dim, num_labels)

        # Matrix of weights for each layer
        # Linear map from hidden layers to alpha for that layer
        # self.attention = nn.Linear(2 * hidden_dim, 1)
        # Attempting feedforward attention, using 2 layers and sigmoid activation fxn
        # Last layer acts as the w_alpha layer
        self.attention = FeedForward(input_dim=2 * hidden_dim, num_layers=2, hidden_dims=[hidden_dim, 1],
                                     activations=nn.Sigmoid())

        # Span embeddings
        self._endpoint_span_extractor = EndpointSpanExtractor(2 * hidden_dim,
                                                              combination="x,y")
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(input_dim=2 * hidden_dim)

        # FFNN for holder/target spans respectively
        self.holder_FFNN = nn.Linear(in_features=3 * 2 * hidden_dim, out_features=3 * 2 * hidden_dim)
        self.target_FFNN = nn.Linear(in_features=3 * 2 * hidden_dim, out_features=3 * 2 * hidden_dim)
#        self.holder_FFNN = FeedForward(input_dim=3 * 2 * hidden_dim, num_layers=2, hidden_dims=[3 * 2 * hidden_dim, 3 * hidden_dim], activations=nn.ReLU())
#        self.target_FFNN = FeedForward(input_dim=3 * 2 * hidden_dim, num_layers=2, hidden_dims=[3 * 2 * hidden_dim, 3 * hidden_dim], activations=nn.ReLU())

        # linear for attention to each pair
        self.pair_attention = nn.Linear(in_features=6 * hidden_dim, out_features=1)
        # self.pairwise_sentiment_score = nn.Linear(in_features=12 * hidden_dim, out_features=6 * hidden_dim)

        # Scoring pairwise sentiment: linear score approach
        '''
        self.final_sentiment_score = FeedForward(input_dim=15 * hidden_dim, num_layers=2,
                                                    hidden_dims=[hidden_dim, num_labels],
                                                    activations=nn.ReLU())
        '''
        # Any possible ablations
        self.ablations = ablations
        ft_dim = 12 * hidden_dim + 6 * feature_embeddings_size
        if ablations is not None:
            ft_dim = 12 * hidden_dim + ABLATION_TO_DIM[ablations] * feature_embeddings_size
        self.final_sentiment_score = nn.Linear(in_features=ft_dim, out_features=num_labels)
Example #19
    def __init__(self,
                 vocab: Vocabulary,
                 text_field_embedder: TextFieldEmbedder,
                 encoder: Seq2SeqEncoder,
                 tag_representation_dim: int,
                 arc_representation_dim: int,
                 tag_feedforward: FeedForward = None,
                 arc_feedforward: FeedForward = None,
                 pos_tag_embedding: Embedding = None,
                 use_mst_decoding_for_validation: bool = True,
                 dropout: float = 0.0,
                 input_dropout: float = 0.0,
                 initializer: InitializerApplicator = InitializerApplicator(),
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(BiaffineChineseDependencyParser,
              self).__init__(vocab, regularizer)

        self.text_field_embedder = text_field_embedder
        self.encoder = encoder

        encoder_dim = encoder.get_output_dim()

        self.head_arc_feedforward = arc_feedforward or \
                                    FeedForward(encoder_dim, 1,
                                                arc_representation_dim,
                                                Activation.by_name("elu")())
        self.child_arc_feedforward = copy.deepcopy(self.head_arc_feedforward)

        self.arc_attention = BilinearMatrixAttention(arc_representation_dim,
                                                     arc_representation_dim,
                                                     use_input_biases=True)

        num_labels = self.vocab.get_vocab_size("head_tags")

        self.head_tag_feedforward = tag_feedforward or \
                                    FeedForward(encoder_dim, 1,
                                                tag_representation_dim,
                                                Activation.by_name("elu")())
        self.child_tag_feedforward = copy.deepcopy(self.head_tag_feedforward)

        self.tag_bilinear = torch.nn.modules.Bilinear(tag_representation_dim,
                                                      tag_representation_dim,
                                                      num_labels)

        self._pos_tag_embedding = pos_tag_embedding or None
        self._dropout = InputVariationalDropout(dropout)
        self._input_dropout = Dropout(input_dropout)
        self._head_sentinel = torch.nn.Parameter(
            torch.randn([1, 1, encoder.get_output_dim()]))

        representation_dim = text_field_embedder.get_output_dim()
        if pos_tag_embedding is not None:
            representation_dim += pos_tag_embedding.get_output_dim()

        # check_dimensions_match(representation_dim, encoder.get_input_dim(),
        #                        "text field embedding dim", "encoder input dim")

        check_dimensions_match(tag_representation_dim,
                               self.head_tag_feedforward.get_output_dim(),
                               "tag representation dim",
                               "tag feedforward output dim")
        check_dimensions_match(arc_representation_dim,
                               self.head_arc_feedforward.get_output_dim(),
                               "arc representation dim",
                               "arc feedforward output dim")

        self.use_mst_decoding_for_validation = use_mst_decoding_for_validation

        tags = self.vocab.get_token_to_index_vocabulary("pos")
        punctuation_tag_indices = {
            tag: index
            for tag, index in tags.items() if tag in POS_TO_IGNORE
        }
        self._pos_to_ignore = set(punctuation_tag_indices.values())
        logger.info(
            f"Found POS tags corresponding to the following punctuation : {punctuation_tag_indices}. "
            "Ignoring words with these POS tags for evaluation.")

        self._attachment_scores = AttachmentScores()

        self._endpoint_span_extractor = EndpointSpanExtractor(
            self.text_field_embedder.get_output_dim(),
            combination="x,y",
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.text_field_embedder.get_output_dim())

        initializer(self)
    def test_attention_is_normalised_correctly(self):
        input_dim = 7
        sequence_tensor = torch.randn([2, 5, input_dim])
        extractor = SelfAttentiveSpanExtractor(input_dim=input_dim)
        assert extractor.get_output_dim() == input_dim
        assert extractor.get_input_dim() == input_dim

        # In order to test the attention, we'll make the weight which computes the logits
        # zero, so the attention distribution is uniform over the sentence. This lets
        # us check that the computed spans are just the averages of their representations.
        extractor._global_attention._module.weight.data.fill_(0.0)
        extractor._global_attention._module.bias.data.fill_(0.0)

        indices = torch.LongTensor([[[1, 3],
                                     [2, 4]],
                                    [[0, 2],
                                     [3, 4]]]) # smaller span tests masking.
        span_representations = extractor(sequence_tensor, indices)
        assert list(span_representations.size()) == [2, 2, input_dim]

        # First element in the batch.
        batch_element = 0
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 1:4, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 2:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())
        # Now the second element in the batch.
        batch_element = 1
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 0:3, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 3:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())


        # Now test the case in which we have some masked spans in our indices.
        indices_mask = torch.LongTensor([[1, 1], [1, 0]])
        span_representations = extractor(sequence_tensor, indices, span_indices_mask=indices_mask)

        # First element in the batch.
        batch_element = 0
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 1:4, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span.
        mean_embeddings = sequence_tensor[batch_element, 2:5, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), mean_embeddings.data.numpy())
        # Now the second element in the batch.
        batch_element = 1
        spans = span_representations[batch_element]
        # First span.
        mean_embeddings = sequence_tensor[batch_element, 0:3, :].mean(0)
        numpy.testing.assert_array_almost_equal(spans[0].data.numpy(), mean_embeddings.data.numpy())
        # Second span was masked, so should be completely zero.
        numpy.testing.assert_array_almost_equal(spans[1].data.numpy(), numpy.zeros([input_dim]))
Example #21
class SCIIE(Model):
    """
    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``text`` ``TextField`` we get as input to the model.
    context_layer : ``Seq2SeqEncoder``
        This layer incorporates contextual information for each word in the document.
    mention_feedforward : ``FeedForward``
        This feedforward network is applied to the span representations which is then scored
        by a linear layer.
    antecedent_feedforward: ``FeedForward``
        This feedforward network is applied to pairs of span representation, along with any
        pairwise features, which is then scored by a linear layer.
    feature_size: ``int``
        The embedding size for all the embedded features, such as distances or span widths.
    max_span_width: ``int``
        The maximum width of candidate spans.
    spans_per_word: float, required.
        A multiplier between zero and one which controls what percentage of candidate mention
        spans we retain with respect to the number of words in the document.
    max_antecedents: int, required.
        For each mention which survives the pruning stage, we consider this many antecedents.
    lexical_dropout: ``int``
        The probability of dropping out dimensions of the embedded text.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    """
    def __init__(self,
                 vocab: Vocabulary,
                 embedding_dim: int,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 lexical_dropout: float = 0.2,
                 mlp_dropout: float = 0.4,
                 embedder_type=None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SCIIE, self).__init__(vocab, regularizer)
        self.class_num = self.vocab.get_vocab_size('labels')
        word_embeddings = get_embeddings(embedder_type, self.vocab,
                                         embedding_dim, True)
        embedding_dim = word_embeddings.get_output_dim()
        self._text_field_embedder = word_embeddings

        context_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedding_dim,
                          feature_size,
                          batch_first=True,
                          bidirectional=True))
        self._context_layer = context_layer

        endpoint_span_extractor_input_dim = context_layer.get_output_dim()
        attentive_span_extractor_input_dim = word_embeddings.get_output_dim()

        self._endpoint_span_extractor = EndpointSpanExtractor(
            endpoint_span_extractor_input_dim,
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=attentive_span_extractor_input_dim)

        # self._span_extractor = PoolingSpanExtractor(embedding_dim,
        #                                             num_width_embeddings=max_span_width,
        #                                             span_width_embedding_dim=feature_size,
        #                                             bucket_widths=False)

        entity_feedforward = FeedForward(
            self._endpoint_span_extractor.get_output_dim() +
            self._attentive_span_extractor.get_output_dim(), 2, 150, F.relu,
            mlp_dropout)
        # entity_feedforward = FeedForward(self._span_extractor.get_output_dim(), 2, 150,
        #                                  F.relu, mlp_dropout)

        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(), 1)))
        self._mention_pruner = Pruner(feedforward_scorer)

        self._entity_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(),
                                self.class_num - 1)))

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        self._metric_all = FBetaMeasure()
        self._metric_avg = NERF1Metric()

    @overrides
    def forward(
            self,  # type: ignore
            text: Dict[str, torch.LongTensor],
            spans: torch.IntTensor,
            labels: torch.IntTensor = None,
            metadata: List[Dict[str, Any]] = None,
            **kwargs) -> Dict[str, torch.Tensor]:
        # Shape: (batch_size, document_length, embedding_size)
        text_embeddings = self._lexical_dropout(
            self._text_field_embedder(text))

        document_length = text_embeddings.size(1)
        num_spans = spans.size(1)

        # Shape: (batch_size, document_length)
        text_mask = util.get_text_field_mask(text).float()

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1).float()
        # SpanFields return -1 when they are used as padding. As we do
        # some comparisons based on span widths when we attend over the
        # span representations that we generate from these indices, we
        # need them to be <= 0. This is only relevant in edge cases where
        # the number of spans we consider after the pruning stage is >= the
        # total number of spans, because in this case, it is possible we might
        # consider a masked span.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()
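        # A padded span such as [-1, -1] becomes [0, 0] after the relu above; it is
        # still a valid (if meaningless) index, and span_mask computed earlier is what
        # marks it as padding downstream.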

        # Shape: (batch_size, document_length, encoding_dim)
        contextualized_embeddings = self._context_layer(
            text_embeddings, text_mask)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        endpoint_span_embeddings = self._endpoint_span_extractor(
            contextualized_embeddings, spans)
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            text_embeddings, spans)

        # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
        span_embeddings = torch.cat(
            [endpoint_span_embeddings, attended_span_embeddings], -1)
        # span_embeddings = self._span_extractor(text_embeddings, spans, span_indices_mask=span_mask)

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * document_length))
        num_spans_to_keep = min(num_spans_to_keep, span_embeddings.shape[1])

        # Shape:    (batch_size, num_spans_to_keep, embedding_size + 2 * encoding_dim + feature_size)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep)
        #           (batch_size, num_spans_to_keep, 1)
        (top_span_embeddings, top_span_mask, top_span_indices,
         top_span_mention_scores) = self._mention_pruner(
             span_embeddings, span_mask, num_spans_to_keep)
        # (batch_size, num_spans_to_keep, 1)
        top_span_mask = top_span_mask.unsqueeze(-1)
        # Shape: (batch_size * num_spans_to_keep)
        # torch.index_select only accepts 1D indices, but here
        # we need to select spans for each element in the batch.
        # This reformats the indices to take into account their
        # index into the batch. We precompute this here to make
        # the multiple calls to util.batched_index_select below more efficient.
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, num_spans)
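        # For example, with num_spans = 10 and top_span_indices = [[1, 3], [0, 2]],
        # the flattened, batch-shifted indices are [1, 3, 10, 12], i.e. offsets into a
        # (batch_size * num_spans, ...) view of the span tensors.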

        # Compute final predictions for which spans to consider as mentions.
        # Shape: (batch_size, num_spans_to_keep, 2)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)

        # Shape: (batch_size, num_spans_to_keep, class_num + 1)
        ne_scores = self._compute_named_entity_scores(top_span_embeddings)

        # Shape: (batch_size, num_spans_to_keep)
        _, predicted_named_entities = ne_scores.max(2)

        output_dict = {
            "top_spans": top_spans,
            "predicted_named_entities": predicted_named_entities
        }
        if labels is not None:
            # Find the gold labels for the spans which we kept.
            # Shape: (batch_size, num_spans_to_keep, 1)
            pruned_gold_labels = util.batched_index_select(
                labels.unsqueeze(-1), top_span_indices,
                flat_top_span_indices).squeeze(-1)
            negative_log_likelihood = F.cross_entropy(
                ne_scores.reshape(-1, self.class_num),
                pruned_gold_labels.reshape(-1))

            pruner_loss = F.binary_cross_entropy_with_logits(
                top_span_mention_scores.reshape(-1),
                (pruned_gold_labels.reshape(-1) != 0).float())
            loss = negative_log_likelihood + pruner_loss
            output_dict["loss"] = loss
            output_dict["pruner_loss"] = pruner_loss
            batch_size, _ = labels.shape
            all_scores = ne_scores.new_zeros(
                [batch_size * num_spans, self.class_num])
            all_scores[:, 0] = 1
            all_scores[flat_top_span_indices] = ne_scores.reshape(
                -1, self.class_num)
            all_scores = all_scores.reshape(
                [batch_size, num_spans, self.class_num])
            self._metric_all(all_scores, labels)
            self._metric_avg(all_scores, labels)
        return output_dict

    @overrides
    def get_metrics(self, reset: bool = False, prefix=""):
        metric = self._metric_all.get_metric(reset)
        metric2 = self._metric_avg.get_metric(reset)
        metric.update(metric2)
        return metric

    def _compute_named_entity_scores(
            self, span_embeddings: torch.FloatTensor) -> torch.Tensor:
        """
        Parameters
        ----------
        span_embeddings : ``torch.FloatTensor``, required.
            Embedding representations of spans. Has shape
            (batch_size, num_spans_to_keep, encoding_dim).

        Returns
        -------
        ne_scores : ``torch.FloatTensor``
            Scores of shape (batch_size, num_spans_to_keep, class_num), where
            index 0 of the last dimension is a fixed zero score for the
            "not an entity" class.
        """
        # Shape: (batch_size, num_spans_to_keep, class_num - 1)
        scores = self._entity_scorer(span_embeddings)
        # Shape: (batch_size, num_spans_to_keep, 1)
        shape = [scores.size(0), scores.size(1), 1]
        dummy_scores = scores.new_full(shape, 0)
        ne_scores = torch.cat([dummy_scores, scores], -1)
        return ne_scores
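_compute_named_entity_scores above prepends a fixed zero logit so that index 0 acts as the "not an entity" class, and cross entropy is then taken over all class_num columns. A minimal, self-contained sketch of that trick with toy tensors (names and sizes here are illustrative, not taken from the model):

import torch
import torch.nn.functional as F

batch_size, num_spans, class_num = 2, 4, 5            # class 0 = "not an entity"
entity_logits = torch.randn(batch_size, num_spans, class_num - 1)

# Prepend a fixed zero logit for the null class, as _compute_named_entity_scores does.
dummy = entity_logits.new_zeros(batch_size, num_spans, 1)
ne_scores = torch.cat([dummy, entity_logits], dim=-1)  # (batch_size, num_spans, class_num)

# Cross entropy over all classes; a gold label of 0 means "this span is not an entity".
gold_labels = torch.randint(0, class_num, (batch_size, num_spans))
loss = F.cross_entropy(ne_scores.reshape(-1, class_num), gold_labels.reshape(-1))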
Example #22
    def __init__(self,
                 contextual_embedding_dim,
                 entity_embedding_dim: int,
                 entity_embeddings: torch.nn.Embedding,
                 max_sequence_length: int = 512,
                 span_encoder_config: Dict[str, int] = None,
                 dropout: float = 0.1,
                 output_feed_forward_hidden_dim: int = 100,
                 initializer_range: float = 0.02,
                 weighted_entity_threshold: float = None,
                 null_entity_id: int = None,
                 include_null_embedding_in_dot_attention: bool = False):
        """
        Idea: align the BERT and KG vector spaces by learning a mapping
            between them.
        """
        super().__init__()

        self.span_extractor = SelfAttentiveSpanExtractor(entity_embedding_dim)
        init_bert_weights(self.span_extractor._global_attention._module,
                          initializer_range)

        self.dropout = torch.nn.Dropout(dropout)

        self.bert_to_kg_projector = torch.nn.Linear(
            contextual_embedding_dim, entity_embedding_dim)
        init_bert_weights(self.bert_to_kg_projector, initializer_range)
        self.projected_span_layer_norm = BertLayerNorm(entity_embedding_dim, eps=1e-5)
        init_bert_weights(self.projected_span_layer_norm, initializer_range)

        self.kg_layer_norm = BertLayerNorm(entity_embedding_dim, eps=1e-5)
        init_bert_weights(self.kg_layer_norm, initializer_range)

        # already pretrained, don't init
        self.entity_embeddings = entity_embeddings
        self.entity_embedding_dim = entity_embedding_dim

        # layers for the dot product attention
        if weighted_entity_threshold is not None or include_null_embedding_in_dot_attention:
            if hasattr(self.entity_embeddings, 'get_null_embedding'):
                null_embedding = self.entity_embeddings.get_null_embedding()
            else:
                null_embedding = self.entity_embeddings.weight[null_entity_id, :]
        else:
            null_embedding = None
        self.dot_attention_with_prior = DotAttentionWithPrior(
            output_feed_forward_hidden_dim,
            weighted_entity_threshold,
            null_embedding,
            initializer_range
        )
        self.null_entity_id = null_entity_id
        self.contextual_embedding_dim = contextual_embedding_dim

        if span_encoder_config is None:
            self.span_encoder = None
        else:
            # create BertConfig
            assert len(span_encoder_config) == 4
            config = BertConfig(
                0,  # vocab size, not used
                hidden_size=span_encoder_config['hidden_size'],
                num_hidden_layers=span_encoder_config['num_hidden_layers'],
                num_attention_heads=span_encoder_config['num_attention_heads'],
                intermediate_size=span_encoder_config['intermediate_size']
            )
            self.span_encoder = BertEncoder(config)
            init_bert_weights(self.span_encoder, initializer_range)
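A hedged sketch of the alignment idea from the docstring above: project contextual vectors into the entity-embedding space, then pool each candidate span with a SelfAttentiveSpanExtractor. The sizes and spans below are toy values chosen for illustration, not the original configuration:

import torch
from allennlp.modules.span_extractors import SelfAttentiveSpanExtractor

contextual_dim, entity_dim = 12, 8
projector = torch.nn.Linear(contextual_dim, entity_dim)
span_extractor = SelfAttentiveSpanExtractor(entity_dim)

contextual = torch.randn(2, 10, contextual_dim)        # (batch_size, seq_len, contextual_dim)
spans = torch.LongTensor([[[0, 2], [4, 4]],
                          [[1, 3], [5, 8]]])           # inclusive (start, end) indices

projected = projector(contextual)                      # (batch_size, seq_len, entity_dim)
span_reps = span_extractor(projected, spans)           # (batch_size, num_spans, entity_dim)
assert span_reps.shape == (2, 2, entity_dim)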
Example #23
    def __init__(self,
                 vocab: Vocabulary,
                 embedding_dim: int,
                 feature_size: int,
                 max_span_width: int,
                 spans_per_word: float,
                 lexical_dropout: float = 0.2,
                 mlp_dropout: float = 0.4,
                 embedder_type=None,
                 regularizer: Optional[RegularizerApplicator] = None) -> None:
        super(SCIIE, self).__init__(vocab, regularizer)
        self.class_num = self.vocab.get_vocab_size('labels')
        word_embeddings = get_embeddings(embedder_type, self.vocab,
                                         embedding_dim, True)
        embedding_dim = word_embeddings.get_output_dim()
        self._text_field_embedder = word_embeddings

        context_layer = PytorchSeq2SeqWrapper(
            torch.nn.LSTM(embedding_dim,
                          feature_size,
                          batch_first=True,
                          bidirectional=True))
        self._context_layer = context_layer

        endpoint_span_extractor_input_dim = context_layer.get_output_dim()
        attentive_span_extractor_input_dim = word_embeddings.get_output_dim()

        self._endpoint_span_extractor = EndpointSpanExtractor(
            endpoint_span_extractor_input_dim,
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=attentive_span_extractor_input_dim)

        # self._span_extractor = PoolingSpanExtractor(embedding_dim,
        #                                             num_width_embeddings=max_span_width,
        #                                             span_width_embedding_dim=feature_size,
        #                                             bucket_widths=False)

        entity_feedforward = FeedForward(
            self._endpoint_span_extractor.get_output_dim() +
            self._attentive_span_extractor.get_output_dim(), 2, 150, F.relu,
            mlp_dropout)
        # entity_feedforward = FeedForward(self._span_extractor.get_output_dim(), 2, 150,
        #                                  F.relu, mlp_dropout)

        feedforward_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(), 1)))
        self._mention_pruner = Pruner(feedforward_scorer)

        self._entity_scorer = torch.nn.Sequential(
            TimeDistributed(entity_feedforward),
            TimeDistributed(
                torch.nn.Linear(entity_feedforward.get_output_dim(),
                                self.class_num - 1)))

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        self._metric_all = FBetaMeasure()
        self._metric_avg = NERF1Metric()
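As a sanity check on the constructor above: the feedforward input dimension is simply the sum of the two extractors' output dimensions, since the "x,y" endpoint combination plus a width embedding yields 2 * encoding_dim + feature_size and the attentive extractor preserves its input dimension. A small sketch with toy sizes:

from allennlp.modules.span_extractors import (EndpointSpanExtractor,
                                               SelfAttentiveSpanExtractor)

encoding_dim, embedding_dim, feature_size, max_span_width = 16, 10, 4, 8  # toy sizes

endpoint = EndpointSpanExtractor(encoding_dim,
                                 combination="x,y",
                                 num_width_embeddings=max_span_width,
                                 span_width_embedding_dim=feature_size,
                                 bucket_widths=False)
attentive = SelfAttentiveSpanExtractor(input_dim=embedding_dim)

assert endpoint.get_output_dim() == 2 * encoding_dim + feature_size
assert attentive.get_output_dim() == embedding_dim
span_emb_dim = endpoint.get_output_dim() + attentive.get_output_dim()  # feedforward input dim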
Example #24
File: dygie.py Project: GillesJ/dygiepp
class DyGIE(Model):
    """
    TODO(dwadden) document me.

    Parameters
    ----------
    vocab : ``Vocabulary``
    text_field_embedder : ``TextFieldEmbedder``
        Used to embed the ``text`` ``TextField`` we get as input to the model.
    context_layer : ``Seq2SeqEncoder``
        This layer incorporates contextual information for each word in the document.
    feature_size: ``int``
        The embedding size for all the embedded features, such as distances or span widths.
    submodule_params: ``TODO(dwadden)``
        A nested dictionary specifying parameters to be passed on to initialize submodules.
    max_span_width: ``int``
        The maximum width of candidate spans.
    max_trigger_span_width: ``int``
        The maximum width of candidate trigger spans.
    target_task: ``str``:
        The task used to make early stopping decisions.
    initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the model parameters.
    module_initializer : ``InitializerApplicator``, optional (default=``InitializerApplicator()``)
        Used to initialize the individual modules.
    regularizer : ``RegularizerApplicator``, optional (default=``None``)
        If provided, will be used to calculate the regularization penalty during training.
    display_metrics : ``List[str]``, optional (default=``None``)
        A list of the metrics that should be printed out during model training.
    """
    def __init__(
            self,
            vocab: Vocabulary,
            embedder: TextFieldEmbedder,
            context_layer: Seq2SeqEncoder,
            modules,  # TODO(dwadden) Add type.
            feature_size: int,
            max_span_width: int,
            max_trigger_span_width: int,
            target_task: str,
            feedforward_params: Dict[str, Union[int, float]],
            loss_weights: Dict[str, float],
            lexical_dropout: float = 0.2,
            use_attentive_span_extractor: bool = False,
            initializer: InitializerApplicator = InitializerApplicator(),
            module_initializer: InitializerApplicator = InitializerApplicator(
            ),
            regularizer: Optional[RegularizerApplicator] = None,
            display_metrics: List[str] = None) -> None:
        super(DyGIE, self).__init__(vocab, regularizer)

        ####################

        # Create span extractor.
        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)
        self._endpoint_trigger_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_trigger_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False)

        ####################
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x

        if use_attentive_span_extractor:
            self._attentive_span_extractor = SelfAttentiveSpanExtractor(
                input_dim=context_layer.get_output_dim())
        else:
            self._attentive_span_extractor = None

        # Set parameters.
        self._embedder = embedder
        self._context_layer = context_layer
        self._loss_weights = loss_weights
        self._max_span_width = max_span_width
        self._max_trigger_span_width = max_trigger_span_width
        self._display_metrics = self._get_display_metrics(target_task)

        trigger_emb_dim = self._endpoint_trigger_span_extractor.get_output_dim()
        span_emb_dim = self._endpoint_span_extractor.get_output_dim()

        if self._attentive_span_extractor is not None:
            span_emb_dim += self._attentive_span_extractor.get_output_dim()
            trigger_emb_dim += self._attentive_span_extractor.get_output_dim()

        ####################

        # Create submodules.

        modules = Params(modules)

        # Helper function to create feedforward networks.
        def make_feedforward(input_dim):
            return FeedForward(input_dim=input_dim,
                               num_layers=feedforward_params["num_layers"],
                               hidden_dims=feedforward_params["hidden_dims"],
                               activations=torch.nn.ReLU(),
                               dropout=feedforward_params["dropout"])

        # Submodules

        self._ner = NERTagger.from_params(vocab=vocab,
                                          make_feedforward=make_feedforward,
                                          span_emb_dim=span_emb_dim,
                                          feature_size=feature_size,
                                          params=modules.pop("ner"))

        self._coref = CorefResolver.from_params(
            vocab=vocab,
            make_feedforward=make_feedforward,
            span_emb_dim=span_emb_dim,
            feature_size=feature_size,
            params=modules.pop("coref"))

        self._relation = RelationExtractor.from_params(
            vocab=vocab,
            make_feedforward=make_feedforward,
            span_emb_dim=span_emb_dim,
            feature_size=feature_size,
            params=modules.pop("relation"))

        self._events = EventExtractor.from_params(
            vocab=vocab,
            make_feedforward=make_feedforward,
            text_emb_dim=self._embedder.get_output_dim(),
            trigger_emb_dim=trigger_emb_dim,
            span_emb_dim=span_emb_dim,
            feature_size=feature_size,
            params=modules.pop("events"))

        ####################

        # Initialize text embedder and all submodules
        for module in [self._ner, self._coref, self._relation, self._events]:
            module_initializer(module)

        initializer(self)

    @staticmethod
    def _get_display_metrics(target_task):
        """
        The `target_task` is the name of the task used to make early stopping decisions; only
        metrics related to this task are displayed.
        """
        lookup = {
            "ner": [
                f"MEAN__{name}"
                for name in ["ner_precision", "ner_recall", "ner_f1"]
            ],
            "relation": [
                f"MEAN__{name}" for name in
                ["relation_precision", "relation_recall", "relation_f1"]
            ],
            "coref": [
                "coref_precision", "coref_recall", "coref_f1",
                "coref_mention_recall"
            ],
            "events":
            [f"MEAN__{name}" for name in ["trig_class_f1", "arg_class_f1"]]
        }
        if target_task not in lookup:
            raise ValueError(
                f"Invalied value {target_task} has been given as the target task."
            )
        return lookup[target_task]

    @staticmethod
    def _debatch(x):
        # TODO(dwadden) Get rid of this when I find a better way to do it.
        return x if x is None else x.squeeze(0)

    @overrides
    def forward(self,
                text,
                trigger_spans,
                spans,
                metadata,
                ner_labels=None,
                coref_labels=None,
                relation_labels=None,
                trigger_labels=None,
                argument_labels=None):
        """
        TODO(dwadden) change this.
        """
        # In AllenNLP, AdjacencyFields are passed in as floats. This fixes it.
        if relation_labels is not None:
            relation_labels = relation_labels.long()
        if argument_labels is not None:
            argument_labels = argument_labels.long()

        # TODO(dwadden) Multi-document minibatching isn't supported yet. For now, get rid of the
        # extra dimension in the input tensors. Will return to this once the model runs.
        if len(metadata) > 1:
            raise NotImplementedError(
                "Multi-document minibatching not yet supported.")

        metadata = metadata[0]
        spans = self._debatch(spans)  # (n_sents, max_n_spans, 2)
        trigger_spans = self._debatch(
            trigger_spans)  # (n_sents, max_n_spans, 2)
        ner_labels = self._debatch(ner_labels)  # (n_sents, max_n_spans)
        coref_labels = self._debatch(coref_labels)  #  (n_sents, max_n_spans)
        relation_labels = self._debatch(
            relation_labels)  # (n_sents, max_n_spans, max_n_spans)
        trigger_labels = self._debatch(trigger_labels)  # TODO(dwadden)
        argument_labels = self._debatch(argument_labels)  # TODO(dwadden)

        # Encode using BERT, then debatch.
        # Since the data are batched, we use `num_wrapping_dims=1` to unwrap the document dimension.
        # (1, n_sents, max_sentence_length, embedding_dim)

        # TODO(dwadden) Deal with the case where the input is longer than 512.
        text_embeddings = self._embedder(text, num_wrapping_dims=1)
        # (n_sents, max_n_wordpieces, embedding_dim)
        text_embeddings = self._debatch(text_embeddings)
        # apply lexical dropout
        text_embeddings = self._lexical_dropout(text_embeddings)

        # (n_sents, max_sentence_length)
        text_mask = self._debatch(
            util.get_text_field_mask(text, num_wrapping_dims=1).float())
        sentence_lengths = text_mask.sum(dim=1).long()  # (n_sents)

        # contextualize text embeddings
        text_embeddings = self._context_layer(text_embeddings, text_mask)

        # Create spans, i.e. span_embeddings, masks and span_indices
        span_mask = (spans[:, :, 0] >= 0).float()  # (n_sents, max_n_spans)
        # SpanFields use -1 as padding. Because we index into the sequence (and compare span
        # widths) using these values when we attend over the span representations, the padded
        # indices must be clamped to valid, non-negative positions. This only matters in edge
        # cases where the number of spans kept after pruning is >= the total number of spans,
        # since only then might a masked span be selected.
        spans = F.relu(spans.float()).long()  # (n_sents, max_n_spans, 2)

        # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
        span_embeddings = self._endpoint_span_extractor(text_embeddings, spans)

        trigger_mask = (trigger_spans[:, :, 0] >= 0).float()
        trigger_spans = F.relu(trigger_spans.float()).long()
        trigger_embeddings = self._endpoint_trigger_span_extractor(
            text_embeddings, trigger_spans)

        # Make attended span embeddings
        if self._attentive_span_extractor is not None:
            # Shape: (batch_size, num_spans, embedding_size)
            attended_span_embeddings = self._attentive_span_extractor(
                text_embeddings, spans)
            attended_trigger_span_embeddings = self._attentive_span_extractor(
                text_embeddings, trigger_spans)
            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            span_embeddings = torch.cat(
                [span_embeddings, attended_span_embeddings], -1)
            trigger_embeddings = torch.cat(
                [trigger_embeddings, attended_trigger_span_embeddings], -1)

        # Make calls out to the modules to get results.
        output_coref = {'loss': 0}
        output_ner = {'loss': 0}
        output_relation = {'loss': 0}
        output_events = {'loss': 0}

        # Prune and compute span representations for coreference module
        if self._loss_weights["coref"] > 0 or self._coref.coref_prop > 0:
            output_coref, coref_indices = self._coref.compute_representations(
                spans, span_mask, span_embeddings, sentence_lengths,
                coref_labels, metadata)

        # Propagation of global information to enhance the span embeddings
        if self._coref.coref_prop > 0:
            output_coref = self._coref.coref_propagation(output_coref)
            span_embeddings = self._coref.update_spans(output_coref,
                                                       span_embeddings,
                                                       coref_indices)

        # Make predictions and compute losses for each module
        if self._loss_weights['ner'] > 0:
            output_ner = self._ner(spans, span_mask, span_embeddings,
                                   sentence_lengths, ner_labels, metadata)

        if self._loss_weights['coref'] > 0:
            output_coref = self._coref.predict_labels(output_coref, metadata)

        if self._loss_weights['relation'] > 0:
            output_relation = self._relation(spans, span_mask, span_embeddings,
                                             sentence_lengths, relation_labels,
                                             metadata)

        if self._loss_weights['events'] > 0:
            output_events = self._events(trigger_spans, trigger_mask,
                                         trigger_embeddings, spans, span_mask,
                                         span_embeddings, text_mask,
                                         text_embeddings, sentence_lengths,
                                         trigger_labels, argument_labels,
                                         ner_labels, metadata)

        # Use `get` since there are some cases where the output dict won't have a loss - for
        # instance, when doing prediction.
        loss = (
            self._loss_weights['coref'] * output_coref.get("loss", 0) +
            self._loss_weights['ner'] * output_ner.get("loss", 0) +
            self._loss_weights['relation'] * output_relation.get("loss", 0) +
            self._loss_weights['events'] * output_events.get("loss", 0))

        # Multiply the loss by the weight multiplier for this document.
        weight = metadata.weight if metadata.weight is not None else 1.0
        loss *= torch.tensor(weight)

        output_dict = dict(coref=output_coref,
                           relation=output_relation,
                           ner=output_ner,
                           events=output_events)
        output_dict['loss'] = loss

        output_dict["metadata"] = metadata

        return output_dict

    def update_span_embeddings(self, span_embeddings, span_mask,
                               top_span_embeddings, top_span_mask,
                               top_span_indices):
        # TODO(Ulme) Speed this up by tensorizing

        new_span_embeddings = span_embeddings.clone()
        for sample_nr in range(len(top_span_mask)):
            for top_span_nr, span_nr in enumerate(top_span_indices[sample_nr]):
                if top_span_mask[sample_nr,
                                 top_span_nr] == 0 or span_mask[sample_nr,
                                                                span_nr] == 0:
                    break
                new_span_embeddings[sample_nr,
                                    span_nr] = top_span_embeddings[sample_nr,
                                                                   top_span_nr]
        return new_span_embeddings

    @overrides
    def make_output_human_readable(self, output_dict: Dict[str, torch.Tensor]):
        """
        Converts the list of spans and predicted antecedent indices into clusters
        of spans for each element in the batch.

        Parameters
        ----------
        output_dict : ``Dict[str, torch.Tensor]``, required.
            The result of calling :func:`forward` on an instance or batch of instances.

        Returns
        -------
        The same output dictionary, but with an additional ``clusters`` key:

        clusters : ``List[List[List[Tuple[int, int]]]]``
            A nested list, representing, for each instance in the batch, the list of clusters,
            which are in turn comprised of a list of (start, end) inclusive spans into the
            original document.
        """

        doc = copy.deepcopy(output_dict["metadata"])

        if self._loss_weights["coref"] > 0:
            # TODO(dwadden) Will need to get rid of the [0] when batch training is enabled.
            decoded_coref = self._coref.make_output_human_readable(
                output_dict["coref"])["predicted_clusters"][0]
            sentences = doc.sentences
            sentence_starts = [sent.sentence_start for sent in sentences]
            predicted_clusters = [
                document.Cluster(entry, i, sentences, sentence_starts)
                for i, entry in enumerate(decoded_coref)
            ]
            doc.predicted_clusters = predicted_clusters
            # TODO(dwadden) update the sentences with cluster information.

        if self._loss_weights["ner"] > 0:
            for predictions, sentence in zip(output_dict["ner"]["predictions"],
                                             doc):
                sentence.predicted_ner = predictions

        if self._loss_weights["relation"] > 0:
            for predictions, sentence in zip(
                    output_dict["relation"]["predictions"], doc):
                sentence.predicted_relations = predictions

        if self._loss_weights["events"] > 0:
            for predictions, sentence in zip(
                    output_dict["events"]["predictions"], doc):
                sentence.predicted_events = predictions

        return doc

    def get_metrics(self, reset: bool = False) -> Dict[str, float]:
        """
        Get all metrics from all modules. For the ones that shouldn't be displayed, prefix their
        keys with an underscore.
        """
        metrics_coref = self._coref.get_metrics(reset=reset)
        metrics_ner = self._ner.get_metrics(reset=reset)
        metrics_relation = self._relation.get_metrics(reset=reset)
        metrics_events = self._events.get_metrics(reset=reset)

        # Make sure that there aren't any conflicting names.
        metric_names = (list(metrics_coref.keys()) + list(metrics_ner.keys()) +
                        list(metrics_relation.keys()) +
                        list(metrics_events.keys()))
        assert len(set(metric_names)) == len(metric_names)
        all_metrics = dict(
            list(metrics_coref.items()) + list(metrics_ner.items()) +
            list(metrics_relation.items()) + list(metrics_events.items()))

        # If no list of desired metrics given, display them all.
        if self._display_metrics is None:
            return all_metrics
        # Otherwise only display the selected ones.
        res = {}
        for k, v in all_metrics.items():
            if k in self._display_metrics:
                res[k] = v
            else:
                new_k = "_" + k
                res[new_k] = v
        return res
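The forward pass above repeats a standard trick for padded SpanFields: padding spans arrive as (-1, -1), so a mask is read off the start index and the indices are clamped to non-negative values before they are used to index the sequence. A minimal sketch with toy spans:

import torch
import torch.nn.functional as F

# Two sentences, three candidate spans each; the trailing spans are padding.
spans = torch.LongTensor([[[0, 2], [3, 3], [-1, -1]],
                          [[1, 4], [-1, -1], [-1, -1]]])

span_mask = (spans[:, :, 0] >= 0).float()   # (n_sents, max_n_spans): 1 for real spans
spans = F.relu(spans.float()).long()        # clamp the -1 padding to index 0

assert span_mask.tolist() == [[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]]
assert spans.min().item() == 0              # every index is now a valid position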
Example #25
class SrlE2e(Model):
    """

    # Parameters

    vocab : `Vocabulary`, required
        A Vocabulary, required in order to compute sizes for input/output projections.
    bert_model : `Union[str, BertModel]`, required.
        A string describing the BERT model to load or an already constructed BertModel.
    initializer : `InitializerApplicator`, optional (default=`InitializerApplicator()`)
        Used to initialize the model parameters.
    label_smoothing : `float`, optional (default = `None`)
        Whether or not to use label smoothing on the labels when computing cross entropy loss.
    ignore_span_metric : `bool`, optional (default = `False`)
        Whether to calculate span loss, which is irrelevant when predicting BIO for Open Information Extraction.
    srl_eval_path : `str`, optional (default=`DEFAULT_SRL_EVAL_PATH`)
        The path to the srl-eval.pl script. By default, will use the srl-eval.pl included with allennlp,
        which is located at allennlp/tools/srl-eval.pl . If `None`, srl-eval.pl is not used.
    """
    def __init__(
        self,
        vocab: Vocabulary,
        bert_model: Union[str, BertModel],
        mention_feedforward: FeedForward,
        context_layer: Seq2SeqEncoder = None,
        embedding_dropout: float = 0.0,
        initializer: InitializerApplicator = InitializerApplicator(),
        max_span_width: int = 30,
        feature_size: int = 10,
        spans_per_word: float = 100,
        label_smoothing: float = None,
        ignore_span_metric: bool = False,
        srl_eval_path: str = DEFAULT_SRL_EVAL_PATH,
        **kwargs,
    ) -> None:
        super().__init__(vocab, **kwargs)

        if isinstance(bert_model, str):
            self.bert_model = BertModel.from_pretrained(bert_model)
        else:
            self.bert_model = bert_model

        self.num_classes = self.vocab.get_vocab_size("span_labels")
        if srl_eval_path is not None:
            # For the span based evaluation, we don't want to consider labels
            # for verb, because the verb index is provided to the model.
            self.span_metric = SrlEvalScorer(srl_eval_path,
                                             ignore_classes=["V"])
        else:
            self.span_metric = None
        self.tag_projection_layer = Linear(self.bert_model.config.hidden_size,
                                           self.num_classes)

        self.embedding_dropout = Dropout(p=embedding_dropout)
        self._label_smoothing = label_smoothing
        self.ignore_span_metric = ignore_span_metric

        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1))

        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=self.bert_model.config.hidden_size)
        self.span_representation_dim = self._attentive_span_extractor.get_output_dim()
        self._context_layer = context_layer
        if context_layer is not None:
            self._endpoint_span_extractor = EndpointSpanExtractor(
                context_layer.get_output_dim(),
                combination="x,y",
                num_width_embeddings=max_span_width,
                span_width_embedding_dim=feature_size,
                bucket_widths=False,
            )
            self.span_representation_dim = self._endpoint_span_extractor.get_output_dim()

        self.hidden_layer = torch.nn.Sequential(
            torch.nn.Linear(self.span_representation_dim +
                            self.bert_model.config.hidden_size,
                            self.span_representation_dim,
                            bias=False), torch.nn.ReLU())
        self.output_layer = torch.nn.Linear(self.span_representation_dim,
                                            self.num_classes - 1,
                                            bias=False)

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._ce_loss = torch.nn.CrossEntropyLoss(reduction='none')
        self._bce_loss = torch.nn.BCEWithLogitsLoss(reduction='none')
        initializer(self)

    def forward(  # type: ignore
        self,
        tokens: TextFieldTensors,
        verb_indicator: torch.Tensor,
        sentence_end: torch.LongTensor,
        spans: torch.LongTensor,
        span_labels: torch.LongTensor,
        metadata: List[Any],
        tags: torch.LongTensor = None,
    ):
        """
        # Parameters

        tokens : `TextFieldTensors`, required
            The output of `TextField.as_array()`, which should typically be passed directly to a
            `TextFieldEmbedder`. For this model, this must be a `SingleIdTokenIndexer` which
            indexes wordpieces from the BERT vocabulary.
        verb_indicator: `torch.LongTensor`, required.
            An integer `SequenceFeatureField` representation of the position of the verb
            in the sentence. This should have shape (batch_size, num_tokens) and importantly, can be
            all zeros, in the case that the sentence has no verbal predicate.
        tags : `torch.LongTensor`, optional (default = `None`)
            A torch tensor representing the sequence of integer gold class labels
            of shape `(batch_size, num_tokens)`
        metadata : `List[Dict[str, Any]]`, optional, (default = `None`)
            metadata containing the original words in the sentence, the verb to compute the
            frame for, and start offsets for converting wordpieces back to a sequence of words,
            under 'words', 'verb' and 'offsets' keys, respectively.

        # Returns

        An output dictionary consisting of:
        logits : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            unnormalised log probabilities of the tag classes.
        class_probabilities : `torch.FloatTensor`
            A tensor of shape `(batch_size, num_tokens, tag_vocab_size)` representing
            a distribution of the tag classes per word.
        loss : `torch.FloatTensor`, optional
            A scalar loss to be optimised.
        """
        mask = get_text_field_mask(tokens)
        start = time.time()
        bert_embeddings, _ = self.bert_model(
            input_ids=util.get_token_ids_from_text_field_tensors(tokens),
            # token_type_ids=verb_indicator,
            attention_mask=mask,
        )

        # Shape: (batch_size, num_spans)
        span_mask = (spans[:, :, 0] >= 0).squeeze(-1)
        # SpanFields use -1 as padding. Because we index into the sequence
        # (and compare span widths) using these values when we attend over
        # the span representations, the padded indices must be clamped to
        # valid, non-negative positions. This only matters in edge cases
        # where the number of spans kept after pruning is >= the total
        # number of spans, since only then might a masked span be selected.
        # Shape: (batch_size, num_spans, 2)
        spans = F.relu(spans.float()).long()

        embedded_text_input = self.embedding_dropout(bert_embeddings)
        batch_size, sequence_length, _ = embedded_text_input.size()
        # Shape: (batch_size, num_spans, embedding_size)
        attended_span_embeddings = self._attentive_span_extractor(
            bert_embeddings, spans)

        if self._context_layer is not None:
            contextualized_embeddings = self._context_layer(
                embedded_text_input, mask)
            # Shape: (batch_size, num_spans, 2 * encoding_dim + feature_size)
            endpoint_span_embeddings = self._endpoint_span_extractor(
                contextualized_embeddings, spans)

            # Shape: (batch_size, num_spans, embedding_size + 2 * encoding_dim + feature_size)
            # span_embeddings = torch.cat([endpoint_span_embeddings, attended_span_embeddings], -1)
            span_embeddings = endpoint_span_embeddings
        else:
            span_embeddings = attended_span_embeddings

        # Prune based on mention scores.
        num_spans_to_keep = int(
            math.floor(self._spans_per_word * sequence_length))
        num_spans = spans.shape[1]
        num_spans_to_keep = min(num_spans_to_keep, num_spans)

        # Shape: (batch_size, num_spans)
        span_mention_scores = self._mention_scorer(
            self._mention_feedforward(span_embeddings)).squeeze(-1)
        # Shape: (batch_size, num_spans) for all 3 tensors
        top_span_mention_scores, top_span_mask, top_span_indices = util.masked_topk(
            span_mention_scores, span_mask, num_spans_to_keep)
        verb_index = verb_indicator.argmax(1).unsqueeze(1).unsqueeze(2).repeat(
            1, 1, embedded_text_input.shape[-1])
        verb_embeddings = torch.gather(embedded_text_input, 1, verb_index)
        assert len(
            verb_embeddings.shape) == 3 and verb_embeddings.shape[1] == 1
        verb_embeddings = verb_embeddings.squeeze(1)
        # print(verb_indicator.sum(1, keepdim=True) > 0)
        verb_embeddings = torch.where(
            (verb_indicator.sum(1, keepdim=True) > 0).repeat(
                1, verb_embeddings.shape[-1]), verb_embeddings,
            torch.zeros_like(verb_embeddings))
        # print(verb_embeddings)
        flat_top_span_indices = util.flatten_and_batch_shift_indices(
            top_span_indices, spans.shape[1])
        span_embeddings = util.batched_index_select(span_embeddings,
                                                    top_span_indices,
                                                    flat_top_span_indices)
        top_spans = util.batched_index_select(spans, top_span_indices,
                                              flat_top_span_indices)
        top_span_labels = util.batched_index_select(
            span_labels.unsqueeze(-1), top_span_indices,
            flat_top_span_indices).squeeze(-1)
        concatenated_span_embeddings = torch.cat(
            (span_embeddings, verb_embeddings.unsqueeze(1).repeat(
                1, span_embeddings.shape[1], 1)),
            dim=2)
        # print(concatenated_span_embeddings[:,:,:])
        hidden = self.hidden_layer(concatenated_span_embeddings)
        # print(hidden[1,:,:])
        # print(top_span_indices)
        # print([[span_mention_scores[i,top_span_indices[i,j]].item() for j in range(top_span_indices.shape[1])] for i in range(top_span_labels.shape[0])])
        # print(top_span_mention_scores, self.vocab.get_token_index("O", namespace="span_labels"))
        predictions = self.output_layer(hidden)
        # predictions += top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1)
        predictions = torch.cat(
            (torch.zeros_like(predictions[:, :, :1]), predictions), dim=-1)
        # print(top_span_mention_scores.unsqueeze(-1).repeat(1, 1, self.num_classes-1))

        output_dict = {}
        # We need to retain the mask in the output dictionary
        # so that we can crop the sequences to remove padding
        # when we do viterbi inference in self.make_output_human_readable.
        output_dict["mask"] = mask
        # We add in the offsets here so we can compute the un-wordpieced tags.
        words, verbs, offsets = zip(*[(x["words"], x["verb"], x["offsets"])
                                      for x in metadata])
        output_dict["words"] = list(words)
        output_dict["verb"] = list(verbs)
        output_dict["wordpiece_offsets"] = list(offsets)

        if tags is not None:
            loss = (self._ce_loss(predictions.view(-1, predictions.shape[-1]),
                                  top_span_labels.view(-1)) *
                    top_span_mask.float().view(-1)
                    ).sum() / top_span_mask.float().sum()
            # print(top_span_labels)
            # print(predictions.argmax(-1))
            if not self.ignore_span_metric and self.span_metric is not None and not self.training:
                batch_verb_indices = [
                    example_metadata["verb_index"]
                    for example_metadata in metadata
                ]
                batch_sentences = [
                    example_metadata["words"] for example_metadata in metadata
                ]
                # Get the BIO tags from make_output_human_readable()
                # TODO (nfliu): This is kind of a hack, consider splitting out part
                # of make_output_human_readable() to a separate function.
                batch_bio_predicted_tags = self.get_tags(
                    top_spans, predictions, mask.shape[1], top_span_mask,
                    output_dict)
                from allennlp_models.structured_prediction.models.srl import (
                    convert_bio_tags_to_conll_format, )

                batch_conll_predicted_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_predicted_tags
                ]
                batch_bio_gold_tags = [
                    example_metadata["gold_tags"]
                    for example_metadata in metadata
                ]
                # print('G', batch_bio_gold_tags)
                batch_conll_gold_tags = [
                    convert_bio_tags_to_conll_format(tags)
                    for tags in batch_bio_gold_tags
                ]
                self.span_metric(
                    batch_verb_indices,
                    batch_sentences,
                    batch_conll_predicted_tags,
                    batch_conll_gold_tags,
                )
            output_dict["loss"] = loss
        return output_dict

    def get_tags(self, spans, logits, sequence_length, span_mask, output_dict):
        predicted_tag_ids = logits.argmax(2)
        predicted_tags = []
        for i in range(spans.shape[0]):
            sequence = ["O" for _ in range(sequence_length)]
            for j in range(spans.shape[1]):
                if span_mask[i, j].item() == 0:
                    continue
                tag = predicted_tag_ids[i, j].item()
                if tag != self.vocab.get_token_index("O",
                                                     namespace="span_labels"):
                    start = spans[i, j, 0].item()
                    end = spans[i, j, 1].item()
                    if all([el == "O" for el in sequence[start:end + 1]]):
                        tag = self.vocab.get_token_from_index(
                            tag, namespace="span_labels")
                        sequence[start] = "B-" + tag
                        for index in range(start + 1, end + 1):
                            sequence[index] = "I-" + tag
            predicted_tags.append(
                [sequence[ind] for ind in output_dict["wordpiece_offsets"][i]])
        print(predicted_tags)
        return predicted_tags

    @overrides
    def make_output_human_readable(
            self, output_dict: Dict[str,
                                    torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Does constrained viterbi decoding on class probabilities output in :func:`forward`.  The
        constraint simply specifies that the output tags must be a valid BIO sequence.  We add a
        `"tags"` key to the dictionary with the result.

        NOTE: First, we decode a BIO sequence on top of the wordpieces. This is important; viterbi
        decoding produces low quality output if you decode on top of word representations directly,
        because the model gets confused by the 'missing' positions (which is sensible as it is trained
        to perform tagging on wordpieces, not words).

        Secondly, it's important that the indices we use to recover words from the wordpieces are the
        start_offsets (i.e. offsets which correspond to using the first wordpiece of words which are
        tokenized into multiple wordpieces) as otherwise, we might get an ill-formed BIO sequence
        when we select out the word tags from the wordpiece tags. This happens in the case that a word
        is split into multiple word pieces, and then we take the last tag of the word, which might
        correspond to, e.g., I-V, which would not be allowed as it is not preceded by a B tag.
        """
        all_predictions = output_dict["class_probabilities"]
        sequence_lengths = get_lengths_from_binary_sequence_mask(
            output_dict["mask"]).data.tolist()

        if all_predictions.dim() == 3:
            predictions_list = [
                all_predictions[i].detach().cpu()
                for i in range(all_predictions.size(0))
            ]
        else:
            predictions_list = [all_predictions]
        wordpiece_tags = []
        word_tags = []
        transition_matrix = self.get_viterbi_pairwise_potentials()
        start_transitions = self.get_start_transitions()
        # **************** Different ********************
        # We add in the offsets here so we can compute the un-wordpieced tags.
        for predictions, length, offsets in zip(
                predictions_list, sequence_lengths,
                output_dict["wordpiece_offsets"]):
            max_likelihood_sequence, _ = viterbi_decode(
                predictions[:length],
                transition_matrix,
                allowed_start_transitions=start_transitions)
            tags = [
                self.vocab.get_token_from_index(x, namespace="labels")
                for x in max_likelihood_sequence
            ]

            wordpiece_tags.append(tags)
            word_tags.append([tags[i] for i in offsets])
        output_dict["wordpiece_tags"] = wordpiece_tags
        output_dict["tags"] = word_tags
        return output_dict

    def get_metrics(self, reset: bool = False):
        if self.ignore_span_metric:
            # Return an empty dictionary if ignoring the
            # span metric
            return {}

        else:
            metric_dict = self.span_metric.get_metric(reset=reset)

            # This can be a lot of metrics, as there are 3 per class.
            # We only really care about the overall metrics, so we filter for them here.
            return {x: y for x, y in metric_dict.items() if "overall" in x}

    def get_viterbi_pairwise_potentials(self):
        """
        Generate a matrix of pairwise transition potentials for the BIO labels.
        The only constraint implemented here is that I-XXX labels must be preceded
        by either an identical I-XXX tag or a B-XXX tag. In order to achieve this
        constraint, pairs of labels which do not satisfy this constraint have a
        pairwise potential of -inf.

        # Returns

        transition_matrix : `torch.Tensor`
            A `(num_labels, num_labels)` matrix of pairwise potentials.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)
        transition_matrix = torch.zeros([num_labels, num_labels])

        for i, previous_label in all_labels.items():
            for j, label in all_labels.items():
                # I labels can only be preceded by themselves or
                # their corresponding B tag.
                if i != j and label[
                        0] == "I" and not previous_label == "B" + label[1:]:
                    transition_matrix[i, j] = float("-inf")
        return transition_matrix

    def get_start_transitions(self):
        """
        In the BIO sequence, we cannot start the sequence with an I-XXX tag.
        This transition sequence is passed to viterbi_decode to specify this constraint.

        # Returns

        start_transitions : `torch.Tensor`
            The pairwise potentials between a START token and
            the first token of the sequence.
        """
        all_labels = self.vocab.get_index_to_token_vocabulary("labels")
        num_labels = len(all_labels)

        start_transitions = torch.zeros(num_labels)

        for i, label in all_labels.items():
            if label[0] == "I":
                start_transitions[i] = float("-inf")

        return start_transitions

    default_predictor = "semantic_role_labeling"
Example #26
    def __init__(self, bert_hidden_size: int, linear_hidden_size: int,
                 dist_embed_dim: int, token_dist_ratio: int, use_layers: list):
        super().__init__()
        self.bert_hidden_size = bert_hidden_size
        self.embed_dim = dist_embed_dim
        cat_contexts = 3
        self.use_layers = len(use_layers)
        self.span_extractor = SelfAttentiveSpanExtractor(bert_hidden_size *
                                                         self.use_layers)
        # self.span_extractor = EndpointSpanExtractor(
        #     bert_hidden_size, "x,y,x*y"
        # )
        self.buckets = [-8, -4, -2, -1, 1, 2, 3, 4, 5, 8, 16, 32, 64]

        # self.fc_ABshare = nn.Sequential(
        #     nn.Linear(bert_hidden_size * self.use_layers, linear_hidden_size),
        #     nn.ReLU(),
        # )
        #
        # self.fc_Pshare = nn.Sequential(
        #     nn.Linear(bert_hidden_size * self.use_layers, linear_hidden_size),
        #     nn.ReLU(),
        # )

        self.fc_score = nn.Sequential(
            nn.BatchNorm1d(bert_hidden_size * cat_contexts * self.use_layers),
            nn.Dropout(0.1),
            nn.Linear(bert_hidden_size * cat_contexts * self.use_layers,
                      linear_hidden_size), nn.ReLU(),
            nn.BatchNorm1d(linear_hidden_size), nn.Dropout(0.5),
            nn.Linear(linear_hidden_size, dist_embed_dim * token_dist_ratio))
        self.fc_s = nn.Sequential(
            nn.BatchNorm1d(dist_embed_dim * (token_dist_ratio + 1)),
            nn.Dropout(0.5),
            nn.Linear(dist_embed_dim * (token_dist_ratio + 1),
                      dist_embed_dim * (token_dist_ratio + 1)),
            nn.ReLU(),
        )
        self.fc_final = nn.Sequential(
            nn.Linear(dist_embed_dim * (token_dist_ratio + 1) * 2 + 2, 3))
        self.dist_embed = nn.Embedding(len(self.buckets) + 1,
                                       embedding_dim=dist_embed_dim)

        for i, module in enumerate(self.fc_score):
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                print("Initing batchnorm")
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                    print("Initing linear with weight normalization")
                    # assert model[i].weight_g is not None
                else:
                    nn.init.kaiming_normal_(module.weight)
                    print("Initing linear")
                nn.init.constant_(module.bias, 0)

        for i, module in enumerate(self.fc_final):
            if isinstance(module, (nn.BatchNorm1d, nn.BatchNorm2d)):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
                print("Initing batchnorm")
            elif isinstance(module, nn.Linear):
                if getattr(module, "weight_v", None) is not None:
                    nn.init.uniform_(module.weight_g, 0, 1)
                    nn.init.kaiming_normal_(module.weight_v)
                    print("Initing linear with weight normalization")
                    # assert model[i].weight_g is not None
                else:
                    nn.init.kaiming_normal_(module.weight)
                    print("Initing linear")
                nn.init.constant_(module.bias, 0)
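The self.buckets list above discretizes signed token distances before they are embedded. The forward pass is not shown here, so the bucketing below (using torch.bucketize) is an assumption about how those pieces fit together; the bucket edges are copied from the constructor and everything else is illustrative:

import torch
import torch.nn as nn

buckets = [-8, -4, -2, -1, 1, 2, 3, 4, 5, 8, 16, 32, 64]
dist_embed = nn.Embedding(len(buckets) + 1, embedding_dim=6)    # one row per bucket

# Signed distances between two mention positions (assumed semantics).
distances = torch.tensor([-10, -3, 0, 2, 40, 100])
bucket_ids = torch.bucketize(distances, torch.tensor(buckets))  # values in [0, len(buckets)]
distance_features = dist_embed(bucket_ids)                      # (6, 6) embedded distances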
Example #27
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        context_layer: Seq2SeqEncoder,
        mention_feedforward: FeedForward,
        antecedent_feedforward: FeedForward,
        feature_size: int,
        max_span_width: int,
        spans_per_word: float,
        max_antecedents: int,
        coarse_to_fine: bool = False,
        inference_order: int = 1,
        lexical_dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs
    ) -> None:
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1)
        )
        self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
        self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1)
        )

        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False,
        )
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim()
        )

        # 10 possible distance buckets.
        self._num_distance_buckets = 10
        self._distance_embedding = Embedding(
            embedding_dim=feature_size, num_embeddings=self._num_distance_buckets
        )

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._max_antecedents = max_antecedents

        self._coarse_to_fine = coarse_to_fine
        if self._coarse_to_fine:
            self._coarse2fine_scorer = torch.nn.Linear(
                mention_feedforward.get_input_dim(), mention_feedforward.get_input_dim()
            )
        self._inference_order = inference_order
        if self._inference_order > 1:
            self._span_updating_gated_sum = GatedSum(mention_feedforward.get_input_dim())

        self._mention_recall = MentionRecall()
        self._conll_coref_scores = ConllCorefScores()
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x
        initializer(self)
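When coarse_to_fine is enabled above, the extra Linear layer is typically used for a cheap bilinear pruning score between every pair of kept spans before the full antecedent scorer runs. The sketch below illustrates that pattern with toy tensors; it is not the exact AllenNLP computation:

import torch

span_dim, batch_size, num_kept = 8, 2, 5
coarse2fine = torch.nn.Linear(span_dim, span_dim)
top_span_embeddings = torch.randn(batch_size, num_kept, span_dim)

# Bilinear score (W s_i) . s_j for every ordered pair of kept spans.
projected = coarse2fine(top_span_embeddings)                # (batch, num_kept, span_dim)
coarse_scores = torch.bmm(projected,
                          top_span_embeddings.transpose(1, 2))  # (batch, num_kept, num_kept)

# Only the highest-scoring candidate antecedents per span go on to the fine scorer.
top_scores, top_antecedent_indices = coarse_scores.topk(k=3, dim=-1)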