def to_method_object(self):
    """Build the `BaselineMethod` instance that this enum value names.

    Returns:
        A keyword-based method for `TF_IDF` / `BM25`; for every other known
        member, a `VectorSimilarityMethod` or `VectorMappingMethod` wrapping
        an encoder constructed from the matching TF-Hub module URI.

    Raises:
        ValueError: if `self` equals none of the known members.
    """
    # The two keyword baselines take no encoder, so handle them up front.
    if self == self.TF_IDF:
        return keyword_based.TfIdfMethod()
    if self == self.BM25:
        return keyword_based.BM25Method()

    # Hub module URIs, each shared by a similarity and a mapping variant.
    use_uri = ("https://tfhub.dev/google/"
               "universal-sentence-encoder/2")
    use_large_uri = ("https://tfhub.dev/google/"
                     "universal-sentence-encoder-large/3")
    elmo_uri = "https://tfhub.dev/google/elmo/1"
    bert_small_uri = ("https://tfhub.dev/google/"
                      "bert_uncased_L-12_H-768_A-12/1")
    bert_large_uri = ("https://tfhub.dev/google/"
                      "bert_uncased_L-24_H-1024_A-16/1")

    # Attribute names on `vector_based`; resolved only once a member matches.
    similarity = "VectorSimilarityMethod"
    mapping = "VectorMappingMethod"
    hub_encoder = "TfHubEncoder"
    bert_encoder = "BERTEncoder"

    # (member name, method class, encoder class, module URI), tried in order.
    vector_specs = (
        ("USE_SIM", similarity, hub_encoder, use_uri),
        ("USE_LARGE_SIM", similarity, hub_encoder, use_large_uri),
        ("ELMO_SIM", similarity, hub_encoder, elmo_uri),
        ("USE_MAP", mapping, hub_encoder, use_uri),
        ("USE_LARGE_MAP", mapping, hub_encoder, use_large_uri),
        ("ELMO_MAP", mapping, hub_encoder, elmo_uri),
        ("BERT_SMALL_SIM", similarity, bert_encoder, bert_small_uri),
        ("BERT_SMALL_MAP", mapping, bert_encoder, bert_small_uri),
        ("BERT_LARGE_SIM", similarity, bert_encoder, bert_large_uri),
        ("BERT_LARGE_MAP", mapping, bert_encoder, bert_large_uri),
    )
    for member_name, method_name, encoder_name, uri in vector_specs:
        if self == getattr(self, member_name):
            method_cls = getattr(vector_based, method_name)
            encoder_cls = getattr(vector_based, encoder_name)
            return method_cls(encoder=encoder_cls(uri))

    raise ValueError("Unknown method {}".format(self))
Example #2
0
    def test_encode(self, mock_module_cls):
        """Check `BERTEncoder.encode` against a stubbed hub module.

        The stub serves the two signatures this test expects the encoder to
        request ("tokenization_info" and "tokens") and returns an all-ones
        `sequence_output`, so the expected encodings depend only on how many
        tokens each sentence produces.

        Args:
            mock_module_cls: stand-in for the module class that
                `BERTEncoder` instantiates. Presumably injected by a patch
                decorator that is outside this excerpt -- confirm.
        """
        def mock_module(inputs=None, signature=None, as_dict=None):
            # Stands in for a constructed module object being called.
            self.assertTrue(as_dict)
            if signature == "tokens":
                # `set(inputs)` yields the dict's keys on both Python 2 and
                # Python 3; `dict.viewkeys()` only exists on Python 2.
                self.assertEqual({'input_mask', 'input_ids', 'segment_ids'},
                                 set(inputs))
                batch_size = tf.shape(inputs['input_ids'])[0]
                seq_len = tf.shape(inputs['input_ids'])[1]
                # All-ones output with a feature dimension of 3.
                return {'sequence_output': tf.ones([batch_size, seq_len, 3])}
            # The only other signature the stub accepts.
            self.assertEqual("tokenization_info", signature)
            return {
                'do_lower_case': tf.constant(True),
                'vocab_file': tf.constant(self.vocab_file),
            }

        mock_module_cls.return_value = mock_module

        encoder = vector_based.BERTEncoder("test_uri")
        # Constructing the encoder must instantiate the module class exactly
        # twice, both times as a frozen (non-trainable) module.
        self.assertEqual([(("test_uri", ), {
            'trainable': False
        })] * 2, mock_module_cls.call_args_list)

        # Final encodings will just be the count of the tokens in each
        # sentence, repeated 3 times.
        # NOTE(review): "hello" -> 3 suggests two special tokens are added
        # around the word pieces -- confirm against the tokenizer.
        encodings = encoder.encode(["hello"])
        np.testing.assert_allclose([[3, 3, 3]], encodings)

        encodings = encoder.encode(["hello", "hello hi"])
        np.testing.assert_allclose([[3, 3, 3], [4, 4, 4]], encodings)