def test_get_transforms(
    self,
    mocked_load_transforms_config: MagicMock,
    mocked_load_transforms_filter_config: MagicMock,
) -> None:
    """Check get_transforms against the transforms filter for the domain.

    framework_foo defines transform1-3 and the filter for its domain_bar
    lists transform2-4, so only transform2 and transform3 are expected.
    """
    foo_transforms = [{"name": f"transform{idx}"} for idx in (1, 2, 3)]
    bar_transforms = [{"name": f"transform{idx}"} for idx in (4, 5, 6)]
    mocked_load_transforms_config.return_value = [
        {"name": "framework_foo", "help": None, "params": foo_transforms},
        {"name": "framework_bar", "help": None, "params": bar_transforms},
    ]
    mocked_load_transforms_filter_config.return_value = {
        "framework_foo": {
            "domain_foo": ["transform1", "transform2", "transform3", "transform4"],
            "domain_bar": ["transform2", "transform3", "transform4"],
        },
    }

    request = {"config": {"framework": "framework_foo", "domain": "domain_bar"}}
    subject = Feeder(data=request)
    result = subject.get_transforms()

    mocked_load_transforms_config.assert_called_once()
    mocked_load_transforms_filter_config.assert_called_once()
    self.assertEqual([{"name": "transform2"}, {"name": "transform3"}], result)
def test_get_precisions_for_missing_framework(
    self,
    mocked_load_precisions_config: MagicMock,
) -> None:
    """Check that get_precisions raises when no framework is configured."""
    mocked_load_precisions_config.return_value = {
        "framework_foo": [{name: {}} for name in ("foo", "bar", "baz")],
        "framework_bar": [{"foo": {}}],
    }

    subject = Feeder(data={})

    with self.assertRaisesRegex(ClientErrorException, "Framework not set."):
        subject.get_precisions()
def test_get_precisions_for_unknown_framework(
    self,
    mocked_load_precisions_config: MagicMock,
) -> None:
    """Check that get_precisions yields an empty list for an unlisted framework."""
    mocked_load_precisions_config.return_value = {
        "framework_foo": [{name: {}} for name in ("foo", "bar", "baz")],
        "framework_bar": [{"foo": {}}],
    }

    # framework_baz is not a key of the mocked precisions config.
    subject = Feeder(data={"config": {"framework": "framework_baz"}})
    result = subject.get_precisions()

    self.assertEqual([], result)
def test_get_dataloaders_for_unknown_framework(
    self,
    mocked_load_dataloader_config: MagicMock,
) -> None:
    """Check that get_dataloaders returns nothing for an unlisted framework."""
    mocked_load_dataloader_config.return_value = [
        {"name": framework_name, "help": None, "params": {}}
        for framework_name in ("framework_foo", "framework_bar")
    ]

    subject = Feeder(data={"config": {"framework": "framework_unknown"}})
    result = subject.get_dataloaders()

    mocked_load_dataloader_config.assert_called_once()
    self.assertEqual([], result)
def test_get_metrics_for_unknown_framework(
    self,
    mocked_check_module: MagicMock,
    mocked_load_help_lpot_params: MagicMock,
) -> None:
    """Test that get_metrics returns only the custom metric for an unknown framework."""
    mocked_load_help_lpot_params.return_value = {
        "__help__topk": "help for topk",
        "topk": {
            "__help__k": "help for k in topk",
            "__help__missing_param": "help for missing_param in topk",
        },
        "__help__metric1": "help for metric1",
        "__help__metric3": "help for metric3",
    }
    # Although help texts for topk/metric1/metric3 are loaded, none of them is
    # expected in the result for an unknown framework; only "custom" remains.
    expected = [
        {
            "name": "custom",
            "help": "",
            "value": None,
        },
    ]

    feeder = Feeder(
        data={
            "config": {
                "framework": "unknown_framework",
            },
        },
    )
    actual = feeder.get_metrics()

    mocked_check_module.assert_called_once_with("unknown_framework")
    mocked_load_help_lpot_params.assert_called_once_with("metrics")
    self.assertEqual(expected, actual)
def test_get_models(
    self,
    mocked_load_model_config: MagicMock,
) -> None:
    """Check that get_models lists the models of the configured framework/domain."""
    # domain2 of framework_foo holds model1..model3, each preceded by its help key.
    domain2_models = {}
    for idx in (1, 2, 3):
        domain2_models[f"__help__model{idx}"] = f"help for model {idx}"
        domain2_models[f"model{idx}"] = {}

    mocked_load_model_config.return_value = {
        "__help__framework_foo": "help for framework_foo",
        "framework_foo": {
            "__help__domain1": "help text for framework_foo/domain1",
            "domain1": {},
            "__help__domain2": "help text for framework_foo/domain2",
            "domain2": domain2_models,
            "__help__domain3": "help text for framework_foo/domain3",
            "domain3": {},
        },
        "__help__framework_bar": "help for framework_bar",
        "framework_bar": {
            "__help__domain1": "help text for framework_bar/domain1",
            "domain1": {},
        },
        "__help__framework_baz": "help for framework_baz",
        "framework_baz": {
            "__help__domain1": "help text for framework_baz/domain1",
            "domain1": {},
        },
    }
    wanted = [
        {"name": f"model{idx}", "help": f"help for model {idx}"}
        for idx in (1, 2, 3)
    ]

    request = {"config": {"framework": "framework_foo", "domain": "domain2"}}
    subject = Feeder(data=request)
    result = subject.get_models()

    mocked_load_model_config.assert_called_once()
    self.assertEqual(wanted, result)
def test_feed_with_unknown_param_fails(self) -> None:
    """Check that feed rejects a parameter it has no getter for."""
    subject = Feeder(data={"param": "foo"})
    expected_message = "Could not found method for foo parameter."

    with self.assertRaisesRegex(ClientErrorException, expected_message):
        subject.feed()
def test_get_models_fails_without_domain(self) -> None:
    """Test that get_models fails when config has a framework but no domain."""
    feeder = Feeder(
        data={
            "config": {
                "framework": "framework_foo",
            },
        },
    )

    with self.assertRaisesRegex(ClientErrorException, "Domain not set."):
        feeder.get_models()
def test_get_quantization_approaches_for_fake_framework(self) -> None:
    """Check quantization approaches offered for a non-real framework.

    Only post_training_static_quant is expected for framework_foo.
    """
    subject = Feeder(data={"config": {"framework": "framework_foo"}})

    approaches = subject.get_quantization_approaches()
    names = [entry.get("name") for entry in approaches]

    self.assertEqual(["post_training_static_quant"], names)
def test_feed_for_framework_works(self, mocked_get_frameworks: MagicMock) -> None:
    """Test that feed with param "framework" returns the frameworks under that key."""
    frameworks = [
        {"name": "framework1"},
        {"name": "framework2"},
        {"name": "framework3"},
    ]
    mocked_get_frameworks.return_value = frameworks

    expected = {"framework": frameworks}

    feeder = Feeder(data={"param": "framework"})
    actual = feeder.feed()

    # Like the sibling tests, verify the mocked collaborator was actually used.
    mocked_get_frameworks.assert_called_once()
    self.assertEqual(expected, actual)
def test_get_strategies(
    self,
    mocked_load_help_lpot_params: MagicMock,
) -> None:
    """Check the strategy list and the help text attached to each entry.

    strategy_unknown's help must not appear, and strategy3 (no help entry in
    the mocked params) is expected with an empty help string.
    """
    mocked_load_help_lpot_params.return_value = {
        "__help__strategy1": "help1",
        "__help__strategy_unknown": "this should be skipped",
        "__help__strategy2": "help2",
    }
    name_help_pairs = (
        ("strategy1", "help1"),
        ("strategy2", "help2"),
        ("strategy3", ""),
    )
    wanted = [{"name": name, "help": text} for name, text in name_help_pairs]

    result = Feeder.get_strategies()

    mocked_load_help_lpot_params.assert_called_once_with("strategies")
    self.assertEqual(wanted, result)
def test_get_transforms_without_domain(
    self,
    mocked_load_transforms_config: MagicMock,
) -> None:
    """Check that get_transforms returns all framework transforms when no domain is set."""
    foo_transforms = [{"name": f"transform{idx}"} for idx in (1, 2, 3)]
    bar_transforms = [{"name": f"transform{idx}"} for idx in (4, 5, 6)]
    mocked_load_transforms_config.return_value = [
        {"name": "framework_foo", "help": None, "params": foo_transforms},
        {"name": "framework_bar", "help": None, "params": bar_transforms},
    ]
    wanted = [{"name": f"transform{idx}"} for idx in (1, 2, 3)]

    subject = Feeder(data={"config": {"framework": "framework_foo"}})
    result = subject.get_transforms()

    mocked_load_transforms_config.assert_called_once()
    self.assertEqual(wanted, result)
def test_get_dataloaders(
    self,
    mocked_load_dataloader_config: MagicMock,
) -> None:
    """Check that get_dataloaders returns the params of the configured framework."""
    foo_params = [{"name": f"param{idx}"} for idx in (1, 2, 3)]
    mocked_load_dataloader_config.return_value = [
        {"name": "framework_foo", "help": None, "params": foo_params},
        {"name": "framework_bar", "help": None, "params": {}},
    ]

    subject = Feeder(data={"config": {"framework": "framework_foo"}})
    result = subject.get_dataloaders()

    mocked_load_dataloader_config.assert_called_once()
    self.assertEqual(foo_params, result)
def test_get_frameworks(
    self,
    mocked_load_model_config: MagicMock,
    mocked_get_frameworks: MagicMock,
) -> None:
    """Check that get_frameworks keeps only frameworks reported as known.

    The mocked helper reports framework_baz and framework_foo; framework_bar is
    expected to be dropped. The expected output follows the model-config order
    (foo before baz), not the order returned by the helper.
    """
    # NOTE(review): "inculded" is a typo in test data only; it is used
    # identically on both the input and the expected side.
    known_suffix = "is in known frameworks, so should be inculded"
    mocked_load_model_config.return_value = {
        "__help__framework_foo": f"framework_foo {known_suffix}",
        "framework_foo": {
            "__help__domain1": "help text for framework_foo/domain1",
            "domain1": {},
        },
        "__help__framework_bar": "framework_bar is not known, so should be ignored",
        "framework_bar": {
            "__help__domain1": "help text for framework_bar/domain1",
            "domain1": {},
        },
        "__help__framework_baz": f"framework_baz {known_suffix}",
        "framework_baz": {
            "__help__domain1": "help text for framework_baz/domain1",
            "domain1": {},
        },
    }
    mocked_get_frameworks.return_value = ["framework_baz", "framework_foo"]
    wanted = [
        {"name": name, "help": f"{name} {known_suffix}"}
        for name in ("framework_foo", "framework_baz")
    ]

    result = Feeder.get_frameworks()

    self.assertEqual(wanted, result)
def test_get_metrics_for_onnxrt(
    self,
    mocked_check_module: MagicMock,
    mocked_load_help_lpot_params: MagicMock,
) -> None:
    """Test the metrics get_metrics returns for the onnxrt framework.

    Expected, in order: topk (with help and default k), COCOmAP, MSE, RMSE,
    MAE, metric1 (with help) and the custom metric.
    """
    mocked_load_help_lpot_params.return_value = {
        "__help__topk": "help for topk",
        "topk": {
            "__help__k": "help for k in topk",
            "__help__missing_param": "help for missing_param in topk",
        },
        "__help__metric1": "help for metric1",
        "__help__metric3": "help for metric3",
    }
    # Help is loaded for topk's missing_param and for metric3, but neither is
    # expected to appear in the result.
    expected = [
        {
            "name": "topk",
            "help": "help for topk",
            "params": [
                {
                    "name": "k",
                    "help": "help for k in topk",
                    "value": [1, 5],
                },
            ],
        },
        {
            "name": "COCOmAP",
            "help": "",
            "params": [
                {
                    "name": "anno_path",
                    "help": "",
                    "value": "",
                },
            ],
        },
        # MSE, RMSE and MAE share the same single compare_label param.
        *[
            {
                "name": metric_name,
                "help": "",
                "params": [
                    {
                        "name": "compare_label",
                        "help": "",
                        "value": True,
                    },
                ],
            }
            for metric_name in ("MSE", "RMSE", "MAE")
        ],
        {
            "name": "metric1",
            "help": "help for metric1",
            "value": None,
        },
        {
            "name": "custom",
            "help": "",
            "value": None,
        },
    ]

    feeder = Feeder(
        data={
            "config": {
                "framework": "onnxrt",
            },
        },
    )
    actual = feeder.get_metrics()

    mocked_check_module.assert_called_once_with("onnxrt")
    mocked_load_help_lpot_params.assert_called_once_with("metrics")
    self.assertEqual(expected, actual)
def test_get_metrics_fails_without_framework(self) -> None:
    """Test that get_metrics fails when no framework is configured."""
    feeder = Feeder(data={})

    with self.assertRaisesRegex(ClientErrorException, "Framework not set."):
        feeder.get_metrics()
def test_feed_without_param_failes(self) -> None:
    """Check that feed raises when no "param" key is supplied."""
    # NOTE(review): "failes" in the method name is a typo; kept so the test id
    # stays unchanged.
    subject = Feeder(data={})
    expected_message = "Parameter not defined."

    with self.assertRaisesRegex(ClientErrorException, expected_message):
        subject.feed()