    def test_gated_sum_can_run_forward(self):
        a = torch.FloatTensor([1, 2, 3, 4, 5])
        b = -a + 0.1
        weight_value = 2
        gate_value = torch.sigmoid(torch.FloatTensor([1]))
        expected = gate_value * a + (1 - gate_value) * b

        with torch.no_grad():  # disable autograd so we can overwrite the gate parameters in place
            gated_sum = GatedSum(a.size(-1))
            gated_sum._gate.weight *= 0
            gated_sum._gate.weight += weight_value
            gated_sum._gate.bias *= 0

            out = gated_sum(a, b)
            numpy.testing.assert_almost_equal(expected.data.numpy(),
                                              out.data.numpy(),
                                              decimal=5)

        with pytest.raises(ValueError):
            GatedSum(a.size(-1))(a, b.unsqueeze(0))

        with pytest.raises(ValueError):
            GatedSum(100)(a, b)
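
# For reference, a minimal sketch of the module these tests exercise, assuming
# AllenNLP's GatedSum semantics: a single linear gate over the concatenation
# [a; b], squashed by a sigmoid, interpolating the two inputs. The class name
# is illustrative; this is not the library source.
import torch


class GatedSumSketch(torch.nn.Module):
    def __init__(self, input_dim: int) -> None:
        super().__init__()
        self.input_dim = input_dim
        # One scalar gate per position, computed from both inputs jointly.
        self._gate = torch.nn.Linear(input_dim * 2, 1)

    def get_input_dim(self) -> int:
        return self.input_dim

    def get_output_dim(self) -> int:
        return self.input_dim

    def forward(self, input_a: torch.Tensor, input_b: torch.Tensor) -> torch.Tensor:
        if input_a.size() != input_b.size():
            raise ValueError("The input must have the same size.")
        if input_a.size(-1) != self.input_dim:
            raise ValueError("Input size must match `input_dim`.")
        gate = torch.sigmoid(self._gate(torch.cat([input_a, input_b], dim=-1)))
        # gate * a + (1 - gate) * b: gate -> 1 passes a through, gate -> 0 passes b.
        return gate * input_a + (1 - gate) * input_b
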
    def __init__(
        self,
        vocab: Vocabulary,
        text_field_embedder: TextFieldEmbedder,
        context_layer: Seq2SeqEncoder,
        mention_feedforward: FeedForward,
        antecedent_feedforward: FeedForward,
        feature_size: int,
        max_span_width: int,
        spans_per_word: float,
        max_antecedents: int,
        coarse_to_fine: bool = False,
        inference_order: int = 1,
        lexical_dropout: float = 0.2,
        initializer: InitializerApplicator = InitializerApplicator(),
        **kwargs
    ) -> None:
        super().__init__(vocab, **kwargs)

        self._text_field_embedder = text_field_embedder
        self._context_layer = context_layer
        self._mention_feedforward = TimeDistributed(mention_feedforward)
        self._mention_scorer = TimeDistributed(
            torch.nn.Linear(mention_feedforward.get_output_dim(), 1)
        )
        self._antecedent_feedforward = TimeDistributed(antecedent_feedforward)
        self._antecedent_scorer = TimeDistributed(
            torch.nn.Linear(antecedent_feedforward.get_output_dim(), 1)
        )

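        # Span representations follow Lee et al. (2017): contextualized states
        # at the span endpoints plus a learned span-width embedding, combined
        # with an attention-weighted sum of the original token embeddings.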
        self._endpoint_span_extractor = EndpointSpanExtractor(
            context_layer.get_output_dim(),
            combination="x,y",
            num_width_embeddings=max_span_width,
            span_width_embedding_dim=feature_size,
            bucket_widths=False,
        )
        self._attentive_span_extractor = SelfAttentiveSpanExtractor(
            input_dim=text_field_embedder.get_output_dim()
        )

        # Distances between a span and each candidate antecedent are bucketed
        # into 10 semi-logarithmic bins before being embedded.
        self._num_distance_buckets = 10
        self._distance_embedding = Embedding(
            embedding_dim=feature_size, num_embeddings=self._num_distance_buckets
        )

        self._max_span_width = max_span_width
        self._spans_per_word = spans_per_word
        self._max_antecedents = max_antecedents

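        # Coarse-to-fine pruning (Lee et al., 2018) scores candidate
        # antecedents with a cheap bilinear product before running the full
        # pairwise scorer on the survivors.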
        self._coarse_to_fine = coarse_to_fine
        if self._coarse_to_fine:
            self._coarse2fine_scorer = torch.nn.Linear(
                mention_feedforward.get_input_dim(), mention_feedforward.get_input_dim()
            )
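        # With inference_order > 1, span representations are refined over
        # several rounds (higher-order inference): each round gates together
        # the current representation and an antecedent-attended one.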
        self._inference_order = inference_order
        if self._inference_order > 1:
            self._span_updating_gated_sum = GatedSum(mention_feedforward.get_input_dim())

        self._mention_recall = MentionRecall()
        self._conll_coref_scores = ConllCorefScores()
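        # When dropout is disabled, substitute the identity so the forward
        # pass can apply self._lexical_dropout unconditionally.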
        if lexical_dropout > 0:
            self._lexical_dropout = torch.nn.Dropout(p=lexical_dropout)
        else:
            self._lexical_dropout = lambda x: x
        initializer(self)
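
# A small usage sketch, assuming the GatedSum import above (the helper name and
# tensor shapes are hypothetical): how the gated sum could merge current span
# embeddings with antecedent-attended ones during a higher-order inference
# step. This mirrors the intent of _span_updating_gated_sum, not the model's
# exact forward pass.
def _demo_span_update() -> torch.Tensor:
    batch_size, num_spans, embedding_dim = 2, 7, 20
    gated_sum = GatedSum(embedding_dim)
    span_embeddings = torch.randn(batch_size, num_spans, embedding_dim)
    attended_span_embeddings = torch.randn(batch_size, num_spans, embedding_dim)
    # Each refinement round interpolates the old and attended representations.
    return gated_sum(span_embeddings, attended_span_embeddings)
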
    def test_input_output_dim(self):
        dim = 77
        gated_sum = GatedSum(dim)
        numpy.testing.assert_equal(gated_sum.get_input_dim(), dim)
        numpy.testing.assert_equal(gated_sum.get_output_dim(), dim)