示例#1
0
    def test_inference_no_head(self):
        """Compare the pretrained NQ question encoder's output to reference values.

        Loads ``facebook/dpr-question_encoder-single-nq-base``, encodes the
        token ids for "[CLS] hello, is my dog cute? [SEP]", and checks the
        first ten components of the model's first output against a
        hard-coded reference slice (rtol=1e-5, atol=1e-4).
        """
        model = TFDPRQuestionEncoder.from_pretrained(
            "facebook/dpr-question_encoder-single-nq-base")

        input_ids = tf.constant(
            [[101, 7592, 1010, 2003, 2026, 3899, 10140, 1029,
              102]])  # [CLS] hello, is my dog cute? [SEP]
        output = model(input_ids)[0]  # embedding shape = (1, 768)
        # Pin the shape the comment above documents, so a silent change in the
        # output layout fails loudly instead of slicing the wrong values.
        self.assertEqual(tuple(output.shape), (1, 768))

        # compare the actual values for a slice.
        expected_slice = tf.constant([[
            0.03236253,
            0.12753335,
            0.16818509,
            0.00279786,
            0.3896933,
            0.24264945,
            0.2178971,
            -0.02335227,
            -0.08481959,
            -0.14324117,
        ]])
        # assert_allclose reports the mismatching elements on failure, unlike
        # assertTrue(numpy.allclose(...)) which only says "False is not true".
        # rtol=1e-5 is numpy.allclose's default, so the tolerance is unchanged.
        numpy.testing.assert_allclose(
            output[:, :10].numpy(),
            expected_slice.numpy(),
            rtol=1e-5,
            atol=1e-4)
示例#2
0
 def create_and_check_dpr_question_encoder(
     self, config, input_ids, token_type_ids, input_mask, sequence_labels, token_labels, choice_labels
 ):
     """Build a TFDPRQuestionEncoder from ``config`` and check its pooled output shape.

     The model is called three ways: with attention mask and token type ids,
     with token type ids only, and with input ids only. Every call must yield a
     ``pooler_output`` of shape
     ``(self.batch_size, self.projection_dim or self.hidden_size)``.

     ``sequence_labels``, ``token_labels`` and ``choice_labels`` are accepted to
     match the shared ``create_and_check_*`` signature but are not used here.
     """
     model = TFDPRQuestionEncoder(config=config)
     # A falsy projection_dim (e.g. 0/None) means no projection layer, so the
     # pooled output keeps the hidden size.
     expected_shape = (self.batch_size, self.projection_dim or self.hidden_size)

     # Previously each result overwrote the last and only the final call was
     # checked; verify every call variant instead.
     results = (
         model(input_ids, attention_mask=input_mask, token_type_ids=token_type_ids),
         model(input_ids, token_type_ids=token_type_ids),
         model(input_ids),
     )
     for result in results:
         self.parent.assertEqual(result.pooler_output.shape, expected_shape)
示例#3
0
    def test_model_from_pretrained(self):
        """Check that the first checkpoint of each TF DPR model family loads.

        Covers the context encoder, question encoder and reader. The original
        version loaded the context-encoder checkpoint twice through a
        duplicated loop; each family is now visited exactly once.
        """
        archive_lists_and_classes = (
            (TF_DPR_CONTEXT_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, TFDPRContextEncoder),
            (TF_DPR_QUESTION_ENCODER_PRETRAINED_MODEL_ARCHIVE_LIST, TFDPRQuestionEncoder),
            (TF_DPR_READER_PRETRAINED_MODEL_ARCHIVE_LIST, TFDPRReader),
        )
        for archive_list, model_class in archive_lists_and_classes:
            # Only the first checkpoint per family, to bound download cost.
            for model_name in archive_list[:1]:
                model = model_class.from_pretrained(model_name)
                self.assertIsNotNone(
                    model,
                    f"{model_class.__name__}.from_pretrained({model_name!r}) returned None")