    def test_encoder(self):
        # Requires: import numpy, torch and
        # from allennlp.modules.seq2vec_encoders import ClsPooler
        embedding = torch.rand(5, 50, 7)
        encoder = ClsPooler(embedding_dim=7)
        pooled = encoder(embedding, mask=None)

        assert list(pooled.size()) == [5, 7]
        numpy.testing.assert_array_almost_equal(embedding[:, 0], pooled)
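
A minimal standalone sketch of what this test checks (the shapes and names below are my own; assumes AllenNLP is installed):

import torch
from allennlp.modules.seq2vec_encoders import ClsPooler

embedding = torch.rand(2, 10, 16)      # (batch, num_tokens, embedding_dim)
pooler = ClsPooler(embedding_dim=16)
pooled = pooler(embedding, mask=None)  # shape (2, 16)
# With no mask and the default cls_is_last_token=False, the result is
# exactly the first token's embedding:
assert torch.equal(pooled, embedding[:, 0, :])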
Example #2
    def test_cls_at_end(self):
        # With cls_is_last_token=True, the pooler picks out the last
        # non-masked token of each sequence instead of the first.
        embedding = torch.arange(20).reshape(5, 4).unsqueeze(-1).expand(5, 4, 7)
        mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0], [1, 1, 1, 1],
                                 [1, 0, 0, 0], [1, 1, 0, 0]])
        # Sequence lengths are [4, 3, 4, 1, 2], so the last unmasked
        # values in each row of the 5x4 range above are [3, 6, 11, 12, 17].
        expected = torch.LongTensor([3, 6, 11, 12, 17]).unsqueeze(-1).expand(5, 7)

        encoder = ClsPooler(embedding_dim=7, cls_is_last_token=True)
        pooled = encoder(embedding, mask=mask)

        assert list(pooled.size()) == [5, 7]
        numpy.testing.assert_array_almost_equal(expected, pooled)
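
The expected values can be recomputed directly from the mask; a short sketch (the names `lengths` and `last_token` are mine):

import torch

mask = torch.LongTensor([[1, 1, 1, 1], [1, 1, 1, 0], [1, 1, 1, 1],
                         [1, 0, 0, 0], [1, 1, 0, 0]])
embedding = torch.arange(20).reshape(5, 4).unsqueeze(-1).expand(5, 4, 7)
lengths = mask.sum(dim=1)                             # [4, 3, 4, 1, 2]
last_token = embedding[torch.arange(5), lengths - 1]  # equals `expected`, shape (5, 7)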
Example #3
    def init_model(self) -> Model:
        """Build the classification model.

        Uses `self.vocab` as the corpus vocabulary and `self.config` for
        the transformer name and target device.

        Returns:
            Model: the constructed classifier
        """
        bert_text_field_embedder = PretrainedTransformerEmbedder(
            model_name=self.config.model_name)
        tagger = BasicClassifier(
            vocab=self.vocab,
            text_field_embedder=BasicTextFieldEmbedder(
                token_embedders={'tokens': bert_text_field_embedder}),
            # Pool the transformer's [CLS] output; the pooler's width must
            # match the embedder's output dimension.
            seq2vec_encoder=ClsPooler(
                embedding_dim=bert_text_field_embedder.get_output_dim()),
        )
        tagger.to(device=self.config.device)
        return tagger
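
For context, a standalone sketch of the same construction; "bert-base-uncased", the empty Vocabulary, and num_labels=2 are illustrative assumptions, since the original reads these from self.config and self.vocab:

import torch
from allennlp.data import Vocabulary
from allennlp.models import BasicClassifier
from allennlp.modules.text_field_embedders import BasicTextFieldEmbedder
from allennlp.modules.token_embedders import PretrainedTransformerEmbedder
from allennlp.modules.seq2vec_encoders import ClsPooler

# model_name, the vocabulary, and num_labels are assumptions here.
embedder = PretrainedTransformerEmbedder(model_name="bert-base-uncased")
model = BasicClassifier(
    vocab=Vocabulary(),
    text_field_embedder=BasicTextFieldEmbedder({"tokens": embedder}),
    seq2vec_encoder=ClsPooler(embedding_dim=embedder.get_output_dim()),
    num_labels=2,
)
model.to(torch.device("cpu"))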
Example #4
    def __init__(self, input_dim: int, **kwargs):
        # The parent encoder is declared with four times the pooled width.
        super().__init__(input_dim=input_dim * 4, **kwargs)

        # Pools the [CLS] token embedding from BERT's output.
        self._seq2vec_encoder = ClsPooler(embedding_dim=input_dim)
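
The class body is cut off here, so the reason for the input_dim * 4 is not shown. One common pattern that yields a 4x width is the pair-comparison feature vector [u; v; u * v; |u - v|]; the sketch below illustrates only that assumption, not the original class's forward pass:

import torch
from allennlp.modules.seq2vec_encoders import ClsPooler

def pair_features(pooler: ClsPooler,
                  tokens_a: torch.Tensor, mask_a: torch.Tensor,
                  tokens_b: torch.Tensor, mask_b: torch.Tensor) -> torch.Tensor:
    # Hypothetical helper: pool each sequence's [CLS] token, then
    # concatenate the standard comparison features.
    u = pooler(tokens_a, mask=mask_a)    # (batch, input_dim)
    v = pooler(tokens_b, mask=mask_b)    # (batch, input_dim)
    # [u; v; u * v; |u - v|] -> (batch, input_dim * 4)
    return torch.cat([u, v, u * v, (u - v).abs()], dim=-1)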