Esempio n. 1
0
    def test_pruning_constructor(self) -> None:
        """Test that Pruning exposes every value given to its constructor."""
        magnitude_config = {
            "weights": ["layer1.0.conv1.weight", "layer1.0.conv2.weight"],
            "method": "per_channel",
            "init_sparsity": 0.3,
            "target_sparsity": 0.5,
            "start_epoch": 1,
            "end_epoch": 3,
        }
        config = {
            "magnitude": magnitude_config,
            "start_epoch": 0,
            "end_epoch": 2,
            "frequency": 0.5,
            "init_sparsity": 0.25,
            "target_sparsity": 0.75,
        }

        pruning = Pruning(config)

        # Values of the nested "magnitude" section.
        magnitude = pruning.magnitude
        expected_weights = ["layer1.0.conv1.weight", "layer1.0.conv2.weight"]
        self.assertEqual(magnitude.weights, expected_weights)
        self.assertEqual(magnitude.method, "per_channel")
        self.assertEqual(magnitude.init_sparsity, 0.3)
        self.assertEqual(magnitude.target_sparsity, 0.5)
        self.assertEqual(magnitude.start_epoch, 1)
        self.assertEqual(magnitude.end_epoch, 3)

        # Top-level values; several share a key name with the nested section
        # but carry different numbers, so the two levels cannot be confused.
        self.assertEqual(pruning.start_epoch, 0)
        self.assertEqual(pruning.end_epoch, 2)
        self.assertEqual(pruning.frequency, 0.5)
        self.assertEqual(pruning.init_sparsity, 0.25)
        self.assertEqual(pruning.target_sparsity, 0.75)
Esempio n. 2
0
    def initialize(self, data: Dict[str, Any] = {}) -> None:
        """Populate the config attributes from a dictionary.

        ``model_path`` and ``domain`` keep their current values when the
        matching key is missing from ``data``. ``device`` is always
        overwritten and becomes ``None`` when the key is absent. Each nested
        section (``model``, ``quantization``, ``tuning``, ``evaluation``,
        ``pruning``, ``graph_optimization``) is replaced only when ``data``
        holds a dict under that key; any other value is ignored.
        """
        self.model_path = data.get("model_path", self.model_path)
        self.domain = data.get("domain", self.domain)

        model_section = data.get("model")
        if isinstance(model_section, dict):
            self.model = Model(model_section)

        # [Optional] One of "cpu", "gpu"; stored as None when not provided.
        self.device = data.get("device")

        quantization_section = data.get("quantization")
        if isinstance(quantization_section, dict):
            self.quantization = Quantization(quantization_section)

        tuning_section = data.get("tuning")
        if isinstance(tuning_section, dict):
            self.tuning = Tuning(tuning_section)

        evaluation_section = data.get("evaluation")
        if isinstance(evaluation_section, dict):
            self.evaluation = Evaluation(evaluation_section)

        pruning_section = data.get("pruning")
        if isinstance(pruning_section, dict):
            self.pruning = Pruning(pruning_section)

        graph_optimization_section = data.get("graph_optimization")
        if isinstance(graph_optimization_section, dict):
            self.graph_optimization = GraphOptimization(
                graph_optimization_section,
            )
Esempio n. 3
0
    def test_pruning_serializer(self) -> None:
        """Test serialization of a fully populated Pruning config."""
        source_config = {
            "magnitude": {
                "weights": ["layer1.0.conv1.weight", "layer1.0.conv2.weight"],
                "method": "per_channel",
                "init_sparsity": 0.3,
                "target_sparsity": 0.5,
                "start_epoch": 1,
                "end_epoch": 3,
            },
            "start_epoch": 0,
            "end_epoch": 2,
            "frequency": 0.5,
            "init_sparsity": 0.25,
            "target_sparsity": 0.75,
        }
        # Written out independently of the input dict rather than reusing it,
        # so the comparison does not depend on the object handed to Pruning.
        expected = {
            "magnitude": {
                "weights": ["layer1.0.conv1.weight", "layer1.0.conv2.weight"],
                "method": "per_channel",
                "init_sparsity": 0.3,
                "target_sparsity": 0.5,
                "start_epoch": 1,
                "end_epoch": 3,
            },
            "start_epoch": 0,
            "end_epoch": 2,
            "frequency": 0.5,
            "init_sparsity": 0.25,
            "target_sparsity": 0.75,
        }

        serialized = Pruning(source_config).serialize()

        self.assertDictEqual(serialized, expected)
Esempio n. 4
0
    def test_pruning_constructor_defaults(self) -> None:
        """Test that Pruning built without arguments leaves all fields None."""
        pruning = Pruning()

        # Fields of the nested magnitude section.
        magnitude_fields = (
            "weights",
            "method",
            "init_sparsity",
            "target_sparsity",
            "start_epoch",
            "end_epoch",
        )
        for field_name in magnitude_fields:
            self.assertIsNone(getattr(pruning.magnitude, field_name))

        # Fields stored directly on the Pruning object.
        top_level_fields = (
            "start_epoch",
            "end_epoch",
            "frequency",
            "init_sparsity",
            "target_sparsity",
        )
        for field_name in top_level_fields:
            self.assertIsNone(getattr(pruning, field_name))
Esempio n. 5
0
    def test_pruning_serializer_defaults(self) -> None:
        """Test that a default Pruning config serializes to an empty dict."""
        self.assertDictEqual(Pruning().serialize(), {})