Example #1
 def test_loads_tf_encoder(self):
     """Tests that a TF-based model can be loaded."""
     from unittest.mock import patch
     from transformers import TFAutoModelForPreTraining
     model = "bert-base-uncased"
     # Patch from_pretrained so the test never downloads real weights.
     with patch.object(TFAutoModelForPreTraining,
                       'from_pretrained',
                       return_value=MockTFModel(model)):
         encoder = TransformerTFEncoder(pretrained_model_name_or_path=model)
         encoded_batch = encoder.encode(self.texts)
         assert encoded_batch.shape == (2, 768)
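The MockTFModel fixture referenced above is not part of this excerpt. The sketch below is only a hypothetical reconstruction of such a stub, assuming the encoder calls it like a Hugging Face TF model and pools token embeddings over the sequence axis; apart from the 768-dimensional hidden size implied by the shape assertion, every detail is an assumption.

 import numpy as np

 class MockTFModel:
     """Hypothetical stand-in for a TF transformer; the real test fixture is not shown."""

     def __init__(self, model_name, emb_dim=768):
         self.model_name = model_name
         self.emb_dim = emb_dim

     def __call__(self, input_ids, attention_mask=None, **kwargs):
         # Return random token-level embeddings of shape (batch, seq_len, emb_dim);
         # pooling over the sequence axis then yields (batch, 768) vectors.
         input_ids = np.asarray(input_ids)
         return (np.random.rand(*input_ids.shape, self.emb_dim).astype('float32'),)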
Example #2
 def _get_encoder(self, metas):
     # 'max' pooling collapses per-token embeddings into one vector per text.
     return TransformerTFEncoder(
         pooling_strategy='max',
         pretrained_model_name_or_path='bert-base-uncased',
         metas=metas)
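For context, the sketch below shows what a 'max' pooling strategy typically means in this setting: token-level embeddings from a TF transformer are reduced over the sequence axis into a single fixed-size vector per text. It uses the transformers library directly rather than TransformerTFEncoder, so it only illustrates the idea; the encoder's real implementation will differ in details such as masking padded tokens.

 import tensorflow as tf
 from transformers import AutoTokenizer, TFAutoModel

 tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
 model = TFAutoModel.from_pretrained('bert-base-uncased')

 texts = ['hello world', 'goodbye world']
 inputs = tokenizer(texts, return_tensors='tf', padding=True)
 token_embeddings = model(**inputs).last_hidden_state  # (batch, seq_len, 768)

 # Max-pool over the sequence axis -> one 768-dimensional vector per text.
 sentence_embeddings = tf.reduce_max(token_embeddings, axis=1)
 print(sentence_embeddings.shape)  # (2, 768)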
Example #3
 def _get_encoder(self, metas):
     return TransformerTFEncoder(
         pretrained_model_name_or_path='xlnet-base-cased',
         metas=metas)
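The xlnet-base-cased checkpoint used here has the same 768-dimensional hidden size as bert-base-uncased, so tests built on either factory can expect 768-dimensional output vectors. A quick way to confirm the hidden size for any checkpoint, assuming a transformers version that aliases the config attribute consistently across architectures:

 from transformers import AutoConfig

 for name in ('bert-base-uncased', 'xlnet-base-cased'):
     config = AutoConfig.from_pretrained(name)
     print(name, config.hidden_size)  # 768 for both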