Example #1
0
    def test_inference_no_head(self):
        """Compare headless TFTapasModel outputs with reference values.

        Loads the "google/tapas-base-finetuned-wtq" checkpoint, encodes the
        shared single-table inference fixture with ``self.default_tokenizer``
        and checks two things against hard-coded expectations, each with an
        absolute tolerance of 0.0005:

        * the ``[:, :3, :3]`` slice of ``last_hidden_state``;
        * the ``[:, :3]`` slice of ``pooler_output``.
        """
        # The tapas_inter_masklm_base_reset weights would be the ideal
        # reference, but they are not straightforward to reproduce with the
        # TF 1 implementation, so the WTQ base model
        # (tapas_wtq_wikisql_sqa_inter_masklm_base_reset) is used instead.
        tolerance = 0.0005
        checkpoint = "google/tapas-base-finetuned-wtq"

        model = TFTapasModel.from_pretrained(checkpoint)
        table, queries = prepare_tapas_single_inputs_for_inference()
        encoded = self.default_tokenizer(table=table,
                                         queries=queries,
                                         return_tensors="tf")
        outputs = model(**encoded)

        # Sequence output: first three positions x first three features.
        expected_sequence_corner = tf.constant([
            [
                [-0.141581565, -0.599805772, 0.747186482],
                [-0.143664181, -0.602008104, 0.749218345],
                [-0.15169853, -0.603363097, 0.741370678],
            ]
        ])
        sequence_corner = outputs.last_hidden_state[:, :3, :3]
        tf.debugging.assert_near(sequence_corner,
                                 expected_sequence_corner,
                                 atol=tolerance)

        # Pooled output: first three features.
        expected_pooled_head = tf.constant(
            [[0.987518311, -0.970520139, -0.994303405]])
        pooled_head = outputs.pooler_output[:, :3]
        tf.debugging.assert_near(pooled_head,
                                 expected_pooled_head,
                                 atol=tolerance)
Example #2
0
    def create_and_check_model(
        self,
        config,
        input_ids,
        input_mask,
        token_type_ids,
        sequence_labels,
        token_labels,
        labels,
        numeric_values,
        numeric_values_scale,
        float_answer,
        aggregation_labels,
    ):
        """Build a ``TFTapasModel`` from ``config`` and check its output shapes.

        The model is called three times, dropping one optional input each
        time:

        1. ``input_ids``, ``attention_mask`` and ``token_type_ids``;
        2. ``input_ids`` and ``token_type_ids``;
        3. ``input_ids`` alone.

        After every call, ``last_hidden_state`` must have shape
        ``(batch_size, seq_length, hidden_size)``. ``pooler_output`` must have
        shape ``(batch_size, hidden_size)``, using this tester's
        ``batch_size``, ``seq_length`` and ``hidden_size`` attributes.
        Assertions go through ``self.parent``.

        Args:
            config: configuration passed to ``TFTapasModel``.
            input_ids: token ids fed to every call.
            input_mask: passed as ``attention_mask`` on the first call only.
            token_type_ids: passed on the first two calls.
            sequence_labels, token_labels, labels, numeric_values,
            numeric_values_scale, float_answer, aggregation_labels:
                not used by this check. They are presumably accepted so this
                helper shares one signature with sibling ``create_and_check_*``
                helpers (NOTE(review): confirm against the tester class).
        """
        model = TFTapasModel(config=config)

        expected_sequence_shape = (self.batch_size, self.seq_length,
                                   self.hidden_size)
        expected_pooled_shape = (self.batch_size, self.hidden_size)

        # Verify every input configuration, not just the last one, so that a
        # shape regression in the attention_mask / token_type_ids code paths
        # is caught instead of being silently overwritten.
        input_variants = (
            {
                "input_ids": input_ids,
                "attention_mask": input_mask,
                "token_type_ids": token_type_ids,
            },
            {
                "input_ids": input_ids,
                "token_type_ids": token_type_ids,
            },
            {
                "input_ids": input_ids,
            },
        )
        for inputs in input_variants:
            result = model(inputs)
            self.parent.assertEqual(result.last_hidden_state.shape,
                                    expected_sequence_shape)
            self.parent.assertEqual(result.pooler_output.shape,
                                    expected_pooled_shape)