def test_mxnet_invalid_model_artifact_without_json_file(): model_artifacts = ["test.params"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) with pytest.raises(RuntimeError) as errinfo: loader.load_model() assert "InputConfiguration: No symbol file found for MXNet model." in str(errinfo.value)
def test_mxnet_load_model_exception(patch_relay, patch_mxnet, patch_tvm_error): patch_relay.frontend.from_mxnet.side_effect = Exception("Some TVM Error") model_artifacts = ["test.params", "test.json"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) with pytest.raises(RuntimeError, match="InputConfiguration: TVM can't convert the MXNet model."): loader.load_model()
def test_mxnet_invalid_model_artifact_with_multiple_json_files(): model_artifacts = ["test.params", "test1.json", "test2.json"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) with pytest.raises(RuntimeError) as errinfo: loader.load_model() assert "InputConfiguration: Only one symbol file is allowed for MXNet model." in str(errinfo.value)
def test_mxnet_load_model_op_error(patch_relay, patch_mxnet, patch_tvm_error): patch_relay.frontend.from_mxnet.side_effect = MockedOpError model_artifacts = ["test.params", "test.json"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) with pytest.raises(MockedOpError): loader.load_model()
def test_mxnet_symbol_exception(patch_mxnet): patch_mxnet.symbol.load.side_effect = Exception("Bad model json.") model_artifacts = ["test.params", "test.json"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) with pytest.raises(RuntimeError) as errinfo: loader.load_model() assert "InputConfiguration: Framework can't load the MXNet model: Bad model json." in str(errinfo.value)
def test_mxnet(patch_relay, patch_mxnet): patch_relay.frontend.from_mxnet.return_value.__iter__.return_value = MagicMock(), MagicMock() model_artifacts = ["test.params", "test.json"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) loader.load_model() patch_mxnet.symbol.load.assert_called() patch_mxnet.ndarray.load.assert_called() patch_relay.frontend.from_mxnet.assert_called()
def test_mxnet_model_artifact_with_multiple_params_files(patch_relay, patch_mxnet): patch_relay.frontend.from_mxnet.return_value.__iter__.return_value = MagicMock(), MagicMock() model_artifacts = ["resnet-18-symbol.json", "resnet-18-0000.params", "resnet-18-0042.params"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) loader.load_model() patch_mxnet.symbol.load.assert_called() patch_mxnet.ndarray.load.assert_called() patch_relay.frontend.from_mxnet.assert_called()
def test_mxnet_ndarray_invalid_params_exception(patch_mxnet): def mock_mxnet_ndarray_load(param_file): return {"arg:foo": param_file, "bad:aux": param_file} patch_mxnet.ndarray.load = mock_mxnet_ndarray_load model_artifacts = ["test.params", "test.json"] data_shape = {"data": [1, 3, 224, 224]} loader = MxNetModelLoader(model_artifacts, data_shape) msg = "InputConfiguration: Framework can't load the MXNet model: Please use HybridBlock.export()" with pytest.raises(RuntimeError) as errinfo: loader.load_model() assert msg in str(errinfo.value)