Esempio n. 1
0
    def test_set_transform_on_empty_config(self) -> None:
        """Check set_transform on a Config built without a predefined config.

        Three transforms (including one named "SquadV1") are applied; the
        test then requires both the evaluation and quantization sections of
        the config to still be None.
        """
        empty_config = Config()
        transforms = [
            {"name": "Some transform1", "params": {"param12": True}},
            {"name": "SquadV1", "params": {"param1": True}},
            {"name": "Some transform2", "params": {"param123": True}},
        ]

        empty_config.set_transform(transforms)

        # Checked one at a time, in the same order as before, so the first
        # non-None section is the one reported on failure.
        self.assertIsNone(empty_config.evaluation)
        self.assertIsNone(empty_config.quantization)
Esempio n. 2
0
    def test_set_transform(self) -> None:
        """Check how set_transform distributes transforms on a predefined Config.

        The "SquadV1" entry must end up in the evaluation accuracy
        postprocess transforms. The other two entries must appear, in their
        original order, as the keys of the quantization calibration,
        evaluation accuracy and evaluation performance dataloader transforms.
        """
        predefined = Config(self.predefined_config)
        predefined.set_transform(
            [
                {"name": "Some transform1", "params": {"param12": True}},
                {"name": "SquadV1", "params": {"param1": True}},
                {"name": "Some transform2", "params": {"param123": True}},
            ],
        )

        expected_postprocess = {"SquadV1": {"param1": True}}
        self.assertEqual(
            expected_postprocess,
            predefined.evaluation.accuracy.postprocess.transform,
        )

        # Each dataloader transform is looked up immediately before its own
        # assertion so attribute access and failure order stay sequential.
        expected_names = ["Some transform1", "Some transform2"]

        calibration_tf = (
            predefined.quantization.calibration.dataloader.transform
        )
        self.assertEqual(expected_names, list(calibration_tf.keys()))

        accuracy_tf = predefined.evaluation.accuracy.dataloader.transform
        self.assertEqual(expected_names, list(accuracy_tf.keys()))

        performance_tf = predefined.evaluation.performance.dataloader.transform
        self.assertEqual(expected_names, list(performance_tf.keys()))