Ejemplo n.º 1
0
    def test_values_override(self):
        """
        Ensure the use_past flag correctly sets the `use_cache` value in the model's configuration.

        For every (name, config) pair in `SUPPORTED_WITH_PAST_CONFIGS`, each checked in its own subTest:
        - `OnnxConfigWithPast.from_model_config(config())` must expose a `values_override` mapping
          containing a falsy `use_cache`;
        - `OnnxConfigWithPast.with_past(config())` must expose a `values_override` mapping
          containing a truthy `use_cache`.
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):

                # without past
                onnx_config_default = OnnxConfigWithPast.from_model_config(
                    config())
                self.assertIsNotNone(onnx_config_default.values_override,
                                     "values_override should not be None")
                self.assertIn("use_cache", onnx_config_default.values_override,
                              "use_cache should be present")
                self.assertFalse(
                    onnx_config_default.values_override["use_cache"],
                    "use_cache should be False if not using past")

                # with past
                onnx_config_past = OnnxConfigWithPast.with_past(config())
                self.assertIsNotNone(onnx_config_past.values_override,
                                     "values_override should not be None")
                self.assertIn("use_cache", onnx_config_past.values_override,
                              "use_cache should be present")
                self.assertTrue(
                    onnx_config_past.values_override["use_cache"],
                    # Previously copy-pasted from the no-past branch ("... False if not
                    # using past"), which contradicted the assertion being made here.
                    "use_cache should be True if using past")
Ejemplo n.º 2
0
    def test_use_past(self):
        """
        Ensure the use_past attribute is set according to the constructor used.

        For every (name, config) pair in `SUPPORTED_WITH_PAST_CONFIGS`, each checked in its own subTest:
        - `OnnxConfigWithPast.from_model_config(config())` must have a falsy `use_past`;
        - `OnnxConfigWithPast.with_past(config())` must have a truthy `use_past`.
        """
        for name, config in OnnxConfigWithPastTestCaseV2.SUPPORTED_WITH_PAST_CONFIGS:
            with self.subTest(name):
                self.assertFalse(
                    OnnxConfigWithPast.from_model_config(config()).use_past,
                    "OnnxConfigWithPast.from_model_config() should not use_past",
                )

                self.assertTrue(
                    OnnxConfigWithPast.with_past(config()).use_past,
                    # The message must name the constructor actually under test; it
                    # previously said from_model_config(), pointing failures at the wrong API.
                    "OnnxConfigWithPast.with_past() should use_past",
                )