def test_tuning_serializer(self) -> None:
    """Test Tuning config serializer."""
    data = {
        "strategy": {
            "name": "mse",
            "accuracy_weight": 0.5,
            "latency_weight": 1.0,
        },
        "accuracy_criterion": {
            "relative": 0.01,
            "absolute": 0.02,
        },
        "objective": "performance",
        "exit_policy": {
            "timeout": 60,
            "max_trials": 200,
        },
        "random_seed": 12345,
        "tensorboard": True,
        "workspace": {
            "path": "/path/to/workspace",
            "resume": "/path/to/snapshot/file",
        },
        "additional_field": {"key": "val"},
    }
    tuning = Tuning(data)

    result = tuning.serialize()

    self.assertDictEqual(
        result,
        {
            "strategy": {
                "name": "mse",
                "accuracy_weight": 0.5,
                "latency_weight": 1.0,
            },
            "accuracy_criterion": {
                "relative": 0.01,
                "absolute": 0.02,
            },
            "objective": "performance",
            "exit_policy": {
                "timeout": 60,
                "max_trials": 200,
            },
            "random_seed": 12345,
            "tensorboard": True,
            "workspace": {
                "path": "/path/to/workspace",
                "resume": "/path/to/snapshot/file",
            },
        },
    )
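# A hedged sketch of the serializer behavior the test above exercises: only
# fields the Tuning schema models survive serialization, so the unknown
# "additional_field" key is dropped. serialize_known_fields is a hypothetical
# helper (assuming the usual typing imports), not the project's implementation.
def serialize_known_fields(obj: Any, known_fields: List[str]) -> Dict[str, Any]:
    """Serialize recognized, non-empty attributes of a config object."""
    result: Dict[str, Any] = {}
    for name in known_fields:
        value = getattr(obj, name, None)
        if value is None:
            continue
        # Nested config sections are assumed to expose their own serialize().
        result[name] = value.serialize() if hasattr(value, "serialize") else value
    return result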
def test_set_timeout_with_exit_policy(self) -> None:
    """Test overwriting timeout in Tuning config."""
    tuning = Tuning(
        {
            "exit_policy": {
                "timeout": 60,
            },
        },
    )
    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.timeout, 60)

    tuning.set_timeout(10)

    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.timeout, 10)
def test_set_max_trials_with_exit_policy(self) -> None:
    """Test overwriting max_trials in Tuning config."""
    tuning = Tuning(
        {
            "exit_policy": {
                "max_trials": 60,
            },
        },
    )
    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.max_trials, 60)

    tuning.set_max_trials(10)

    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.max_trials, 10)
def initialize(self, data: Dict[str, Any] = {}) -> None:
    """Initialize config from dict."""
    self.model_path = data.get("model_path", self.model_path)
    self.domain = data.get("domain", self.domain)
    if isinstance(data.get("model"), dict):
        self.model = Model(data.get("model", {}))
    # [Optional] One of "cpu", "gpu"; default cpu
    self.device = data.get("device", None)
    if isinstance(data.get("quantization"), dict):
        self.quantization = Quantization(data.get("quantization", {}))
    if isinstance(data.get("tuning"), dict):
        self.tuning = Tuning(data.get("tuning", {}))
    if isinstance(data.get("evaluation"), dict):
        self.evaluation = Evaluation(data.get("evaluation", {}))
    if isinstance(data.get("pruning"), dict):
        self.pruning = Pruning(data.get("pruning", {}))
    if isinstance(data.get("graph_optimization"), dict):
        self.graph_optimization = GraphOptimization(
            data.get("graph_optimization", {}),
        )
def __init__(self, data: Dict[str, Any] = {}):
    """Initialize Configuration class."""
    super().__init__()
    self._skip.append("model_path")
    self.model_path: str = data.get("model_path", "")
    self.model: Model = Model()
    self.domain: Optional[str] = data.get("domain", None)
    self.device: Optional[str] = None
    self.quantization: Optional[Quantization] = None
    self.tuning: Tuning = Tuning()
    self.evaluation: Optional[Evaluation] = None
    self.pruning: Optional[Pruning] = None
    self.graph_optimization: Optional[GraphOptimization] = None
    self.initialize(data)
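# Illustrative usage sketch (not part of the original sources): building a
# Configuration from a plain dict. Field names follow initialize() above;
# the concrete path and values here are hypothetical.
def example_configuration_usage() -> None:
    """Show how nested config sections are populated from request data."""
    config = Configuration(
        {
            "model_path": "/path/to/model.pb",
            "tuning": {
                "exit_policy": {
                    "timeout": 60,
                },
            },
        },
    )
    assert config.model_path == "/path/to/model.pb"
    assert config.tuning.exit_policy.timeout == 60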
def test_tuning_constructor(self) -> None:
    """Test Tuning config constructor."""
    data = {
        "strategy": {
            "name": "mse",
            "accuracy_weight": 0.5,
            "latency_weight": 1.0,
        },
        "accuracy_criterion": {
            "relative": 0.01,
            "absolute": 0.02,
        },
        "objective": "performance",
        "exit_policy": {
            "timeout": 60,
            "max_trials": 200,
        },
        "random_seed": 12345,
        "tensorboard": True,
        "workspace": {
            "path": "/path/to/workspace",
            "resume": "/path/to/snapshot/file",
        },
    }
    tuning = Tuning(data)

    self.assertIsNotNone(tuning.strategy)
    self.assertEqual(tuning.strategy.name, "mse")
    self.assertEqual(tuning.strategy.accuracy_weight, 0.5)
    self.assertEqual(tuning.strategy.latency_weight, 1.0)
    self.assertIsNotNone(tuning.accuracy_criterion)
    self.assertEqual(tuning.accuracy_criterion.relative, 0.01)
    self.assertEqual(tuning.accuracy_criterion.absolute, 0.02)
    self.assertEqual(tuning.objective, "performance")
    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.timeout, 60)
    self.assertEqual(tuning.exit_policy.max_trials, 200)
    self.assertEqual(tuning.random_seed, 12345)
    self.assertTrue(tuning.tensorboard)
    self.assertIsNotNone(tuning.workspace)
    self.assertEqual(tuning.workspace.path, "/path/to/workspace")
    self.assertEqual(tuning.workspace.resume, "/path/to/snapshot/file")
def test_tuning_constructor_defaults(self) -> None:
    """Test Tuning config constructor defaults."""
    tuning = Tuning()

    self.assertIsNotNone(tuning.strategy)
    self.assertEqual(tuning.strategy.name, "basic")
    self.assertIsNone(tuning.strategy.accuracy_weight)
    self.assertIsNone(tuning.strategy.latency_weight)
    self.assertIsNotNone(tuning.accuracy_criterion)
    self.assertIsNone(tuning.accuracy_criterion.relative)
    self.assertIsNone(tuning.accuracy_criterion.absolute)
    self.assertIsNone(tuning.objective)
    self.assertIsNone(tuning.exit_policy)
    self.assertIsNone(tuning.random_seed)
    self.assertIsNone(tuning.tensorboard)
    self.assertIsNotNone(tuning.workspace)
    self.assertIsNone(tuning.workspace.path)
    self.assertIsNone(tuning.workspace.resume)
def test_set_random_seed_invalid_string(self) -> None:
    """Test setting random_seed from invalid string in Tuning config."""
    tuning = Tuning()
    with self.assertRaises(ClientErrorException):
        tuning.set_random_seed("abc")
def test_set_random_seed_from_string(self) -> None:
    """Test setting random_seed from string in Tuning config."""
    tuning = Tuning()
    tuning.set_random_seed("123456")
    self.assertEqual(tuning.random_seed, 123456)
def test_set_random_seed(self) -> None:
    """Test setting random_seed in Tuning config."""
    tuning = Tuning()
    tuning.set_random_seed(123456)
    self.assertEqual(tuning.random_seed, 123456)
def test_set_max_trials_invalid_string(self) -> None:
    """Test setting max_trials from invalid string in Tuning config."""
    tuning = Tuning()
    with self.assertRaises(ClientErrorException):
        tuning.set_max_trials("abc")
def test_set_max_trials_negative(self) -> None:
    """Test setting max_trials to a negative value in Tuning config."""
    tuning = Tuning()
    with self.assertRaises(ClientErrorException):
        tuning.set_max_trials(-1)
def test_set_max_trials_from_string(self) -> None:
    """Test setting max_trials from string in Tuning config."""
    tuning = Tuning()
    tuning.set_max_trials("10")
    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.max_trials, 10)
def test_set_max_trials(self) -> None:
    """Test setting max_trials in Tuning config."""
    tuning = Tuning()
    tuning.set_max_trials(10)
    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.max_trials, 10)
def test_set_timeout_negative(self) -> None:
    """Test setting timeout to a negative value in Tuning config."""
    tuning = Tuning()
    with self.assertRaises(ClientErrorException):
        tuning.set_timeout(-1)
def test_set_timeout_from_string(self) -> None:
    """Test setting timeout from string in Tuning config."""
    tuning = Tuning()
    tuning.set_timeout("10")
    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.timeout, 10)
def test_set_timeout(self) -> None:
    """Test setting timeout in Tuning config."""
    tuning = Tuning()
    tuning.set_timeout(10)
    self.assertIsNotNone(tuning.exit_policy)
    self.assertEqual(tuning.exit_policy.timeout, 10)
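# A minimal sketch, inferred from the tests above, of the contract that
# Tuning.set_timeout encodes: coerce to int, reject invalid or negative values
# with ClientErrorException, and create the exit policy on demand. "ExitPolicy"
# is a hypothetical name for whatever class backs tuning.exit_policy; this is
# an illustration of the tested behavior, not the project's implementation.
def set_timeout(self, timeout: Any) -> None:
    """Update timeout in tuning config, creating exit_policy if needed."""
    try:
        timeout = int(timeout)
    except ValueError as err:
        raise ClientErrorException("Timeout must be an integer.") from err
    if timeout < 0:
        raise ClientErrorException("Timeout must not be negative.")
    if self.exit_policy is None:
        self.exit_policy = ExitPolicy({"timeout": timeout})
    else:
        self.exit_policy.timeout = timeout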