Пример #1
0
    def initialize(self, data: Dict[str, Any] = {}) -> None:
        """Populate this config from a plain dictionary.

        ``model_path`` and ``domain`` keep their current values when the
        matching key is absent. ``device`` is always overwritten and becomes
        ``None`` when the key is missing. Each nested section (``model``,
        ``quantization``, ``tuning``, ``evaluation``, ``pruning``,
        ``graph_optimization``) is rebuilt only when its value is a dict;
        otherwise the existing attribute is left untouched.

        Args:
            data: Raw configuration mapping. This method only reads it, so
                the shared ``{}`` default is never modified here.
        """
        lookup = data.get

        self.model_path = lookup("model_path", self.model_path)
        self.domain = lookup("domain", self.domain)

        model_section = lookup("model")
        if isinstance(model_section, dict):
            self.model = Model(model_section)

        # [Optional] Target device, one of "cpu" or "gpu". Stored as None when
        # the key is missing; presumably interpreted as "cpu" elsewhere (verify).
        self.device = lookup("device", None)

        quantization_section = lookup("quantization")
        if isinstance(quantization_section, dict):
            self.quantization = Quantization(quantization_section)

        tuning_section = lookup("tuning")
        if isinstance(tuning_section, dict):
            self.tuning = Tuning(tuning_section)

        evaluation_section = lookup("evaluation")
        if isinstance(evaluation_section, dict):
            self.evaluation = Evaluation(evaluation_section)

        pruning_section = lookup("pruning")
        if isinstance(pruning_section, dict):
            self.pruning = Pruning(pruning_section)

        graph_optimization_section = lookup("graph_optimization")
        if isinstance(graph_optimization_section, dict):
            self.graph_optimization = GraphOptimization(
                graph_optimization_section)
Пример #2
0
    def test_graph_optimization_serializer(self) -> None:
        """Serialize a fully populated config and check the normalized output."""

        def build_op_wise() -> dict:
            # Return a brand-new dict on every call so the constructor input
            # and the expected value never share mutable objects.
            return {
                "weight": {
                    "granularity": "per_channel",
                    "scheme": "asym",
                    "dtype": "bf16",
                    "algorithm": "kl",
                },
                "activation": {
                    "granularity": "per_tensor",
                    "scheme": "sym",
                    "dtype": "int8",
                    "algorithm": "minmax",
                },
            }

        source = {"precisions": "bf16, fp32", "op_wise": build_op_wise()}
        serialized = GraphOptimization(source).serialize()

        # The space after the comma in "precisions" is expected to be dropped.
        expected = {"precisions": "bf16,fp32", "op_wise": build_op_wise()}
        self.assertDictEqual(serialized, expected)
Пример #3
0
 def set_optimization_precision(self, framework: str,
                                precision: str) -> None:
     """Set the graph optimization precision after validating it.

     Args:
         framework: Framework name used as the key into the precisions
             config returned by ``load_precisions_config()``.
         precision: Requested precision; it must equal the ``"name"`` of one
             of the entries configured for ``framework``.

     Raises:
         ClientErrorException: If ``precision`` is not listed for
             ``framework``.
     """
     precisions_config = load_precisions_config().get(framework, [])
     # Use a distinct loop name so the ``precision`` argument is not
     # shadowed inside the comprehension.
     available_precisions = [
         entry.get("name") for entry in precisions_config
     ]
     if precision not in available_precisions:
         raise ClientErrorException(
             f"Precision {precision} is not supported "
             f"in graph optimization for framework {framework}.", )
     if self.graph_optimization is None:
         self.graph_optimization = GraphOptimization(
             {"precisions": precision})
     else:
         # Go through the setter instead of assigning the attribute directly,
         # so an existing config stores the value in the same normalized form
         # ("a,b" without spaces) that the constructor produces.
         self.graph_optimization.set_precisions(precision)
Пример #4
0
    def test_graph_optimization_constructor(self) -> None:
        """Check the constructor normalizes precisions and keeps op_wise as-is."""

        def make_op_wise() -> dict:
            # Fresh dict per call: one copy feeds the constructor, another
            # serves as the independent expected value.
            return {
                "weight": {
                    "granularity": "per_channel",
                    "scheme": "asym",
                    "dtype": "bf16",
                    "algorithm": "kl",
                },
                "activation": {
                    "granularity": "per_tensor",
                    "scheme": "sym",
                    "dtype": "int8",
                    "algorithm": "minmax",
                },
            }

        config = GraphOptimization(
            {"precisions": "bf16, fp32", "op_wise": make_op_wise()},
        )

        self.assertEqual(config.precisions, "bf16,fp32")
        self.assertIsNotNone(config.op_wise)
        self.assertDictEqual(config.op_wise, make_op_wise())
Пример #5
0
 def test_set_precisions_error(self) -> None:
     """An integer passed to set_precisions must raise ClientErrorException."""
     config = GraphOptimization()
     self.assertRaises(ClientErrorException, config.set_precisions, 1)
Пример #6
0
 def test_set_precisions_list(self) -> None:
     """A list of precisions is stripped item-by-item and comma-joined."""
     config = GraphOptimization()
     requested = ["bf16", "fp32 ", " int8"]
     config.set_precisions(requested)
     expected = "bf16,fp32,int8"
     self.assertEqual(config.precisions, expected)
Пример #7
0
 def test_set_precisions_string(self) -> None:
     """A comma-separated string loses its surrounding and inner spaces."""
     config = GraphOptimization()
     raw_value = " bf16, fp32 "
     config.set_precisions(raw_value)
     self.assertEqual(config.precisions, "bf16,fp32")
Пример #8
0
    def test_graph_optimization_constructor_defaults(self) -> None:
        """A config built without data leaves precisions and op_wise as None."""
        empty_config = GraphOptimization()

        for attribute in ("precisions", "op_wise"):
            self.assertIsNone(getattr(empty_config, attribute))
Пример #9
0
    def test_graph_optimization_serializer_defaults(self) -> None:
        """Serializing a config built without data yields an empty dict."""
        self.assertDictEqual(GraphOptimization().serialize(), {})