def to_method_object(self):
    """Convert the enum to an instance of `BaselineMethod`.

    Returns:
        The `BaselineMethod` implementation (keyword-based or
        vector-based) that this enum value names.

    Raises:
        ValueError: if the enum value has no associated method.
    """
    # TF-hub module URIs shared by several branches below.
    hub = "https://tfhub.dev/google/"
    use_uri = hub + "universal-sentence-encoder/2"
    use_large_uri = hub + "universal-sentence-encoder-large/3"
    elmo_uri = hub + "elmo/1"
    bert_small_uri = hub + "bert_uncased_L-12_H-768_A-12/1"
    bert_large_uri = hub + "bert_uncased_L-24_H-1024_A-16/1"

    # Keyword-based methods take no encoder.
    if self == self.TF_IDF:
        return keyword_based.TfIdfMethod()
    if self == self.BM25:
        return keyword_based.BM25Method()

    # Vector-based methods all follow the same construction pattern,
    # so dispatch on (method class, encoder class, module URI) instead
    # of repeating it in ten near-identical elif arms.
    dispatch = {
        self.USE_SIM: (vector_based.VectorSimilarityMethod,
                       vector_based.TfHubEncoder, use_uri),
        self.USE_LARGE_SIM: (vector_based.VectorSimilarityMethod,
                             vector_based.TfHubEncoder, use_large_uri),
        self.ELMO_SIM: (vector_based.VectorSimilarityMethod,
                        vector_based.TfHubEncoder, elmo_uri),
        self.USE_MAP: (vector_based.VectorMappingMethod,
                       vector_based.TfHubEncoder, use_uri),
        self.USE_LARGE_MAP: (vector_based.VectorMappingMethod,
                             vector_based.TfHubEncoder, use_large_uri),
        self.ELMO_MAP: (vector_based.VectorMappingMethod,
                        vector_based.TfHubEncoder, elmo_uri),
        self.BERT_SMALL_SIM: (vector_based.VectorSimilarityMethod,
                              vector_based.BERTEncoder, bert_small_uri),
        self.BERT_SMALL_MAP: (vector_based.VectorMappingMethod,
                              vector_based.BERTEncoder, bert_small_uri),
        self.BERT_LARGE_SIM: (vector_based.VectorSimilarityMethod,
                              vector_based.BERTEncoder, bert_large_uri),
        self.BERT_LARGE_MAP: (vector_based.VectorMappingMethod,
                              vector_based.BERTEncoder, bert_large_uri),
    }
    if self in dispatch:
        method_cls, encoder_cls, uri = dispatch[self]
        return method_cls(encoder=encoder_cls(uri))
    raise ValueError("Unknown method {}".format(self))
def test_encode(self, mock_module_cls):
    """BERTEncoder.encode pools mocked token embeddings per sentence."""

    def mock_module(inputs=None, signature=None, as_dict=None):
        # Stand-in for the TF-hub BERT module. It is called with
        # signature="tokens" when encoding, and with
        # signature="tokenization_info" during encoder construction.
        self.assertTrue(as_dict)
        if signature == "tokens":
            # set(inputs) instead of the Python 2-only dict.viewkeys(),
            # so the assertion works on both Python 2 and 3.
            self.assertEqual({'input_mask', 'input_ids', 'segment_ids'},
                             set(inputs))
            batch_size = tf.shape(inputs['input_ids'])[0]
            seq_len = tf.shape(inputs['input_ids'])[1]
            # Every token position embeds to [1, 1, 1], so pooling over
            # the sequence reduces to counting tokens (see below).
            return {'sequence_output': tf.ones([batch_size, seq_len, 3])}
        self.assertEqual("tokenization_info", signature)
        return {
            'do_lower_case': tf.constant(True),
            'vocab_file': tf.constant(self.vocab_file),
        }

    mock_module_cls.return_value = mock_module
    encoder = vector_based.BERTEncoder("test_uri")
    # The hub module must be instantiated exactly twice with the same
    # URI, both times non-trainable.
    self.assertEqual([(("test_uri", ), {
        'trainable': False
    })] * 2, mock_module_cls.call_args_list)

    # Final encodings will just be the count of the tokens in each
    # sentence, repeated 3 times.
    encodings = encoder.encode(["hello"])
    np.testing.assert_allclose([[3, 3, 3]], encodings)
    encodings = encoder.encode(["hello", "hello hi"])
    np.testing.assert_allclose([[3, 3, 3], [4, 4, 4]], encodings)