def test_user_module_transformer_with_transform_fn(model_fn):
    """A user module defining ``transform_fn`` gets it wired into the transformer.

    ``model_fn`` is requested as a fixture even though it is not referenced here;
    presumably it patches the default model loader -- confirm against conftest.
    """

    class UserModule:
        def __init__(self):
            # Only transform_fn is provided; no other user hooks are defined.
            self.transform_fn = Mock()

    module_with_transform = UserModule()
    transformer = _user_module_transformer(module_with_transform, MODEL_DIR)

    expected_transform_fn = module_with_transform.transform_fn
    assert transformer._transform_fn == expected_transform_fn
def test_user_module_transformer_module_transformer_no_user_methods(model_fn):
    """With no user module, a Module-returning ``model_fn`` yields a ModuleTransformer.

    Every processing stage (input/predict/output) must fall back to the
    transformer's own default implementation.
    """
    base_module = mx.module.BaseModule()
    model_fn.return_value = base_module

    transformer = _user_module_transformer(None, MODEL_DIR)

    assert isinstance(transformer, ModuleTransformer)
    assert transformer._model == base_module
    assert transformer._model_fn == model_fn

    # Each configured hook should be the transformer's matching default method.
    hook_pairs = (
        ('_input_fn', 'default_input_fn'),
        ('_predict_fn', 'default_predict_fn'),
        ('_output_fn', 'default_output_fn'),
    )
    for configured_attr, default_attr in hook_pairs:
        assert getattr(transformer, configured_attr) == getattr(transformer, default_attr)
def test_user_module_transformer_gluon_transformer_with_user_methods():
    """User-supplied hooks are all used when the user's model_fn returns a Gluon Block."""
    block = mx.gluon.block.Block()

    class UserModule:
        def __init__(self):
            self.input_fn = Mock()
            self.predict_fn = Mock()
            self.output_fn = Mock()

        def model_fn(self, model_dir):
            # Ignores model_dir and hands back the Block built above.
            return block

    user_module = UserModule()
    transformer = _user_module_transformer(user_module, MODEL_DIR)

    assert isinstance(transformer, GluonBlockTransformer)
    assert transformer._model == block
    assert transformer._model_fn == user_module.model_fn

    # Each private hook on the transformer must be the user's corresponding hook.
    for hook_name in ('input_fn', 'predict_fn', 'output_fn'):
        assert getattr(transformer, '_' + hook_name) == getattr(user_module, hook_name)
def test_user_module_transformer_unsupported_model_type(model_fn):
    """A model that is neither an MXNet Module nor a Gluon Block raises ValueError.

    Assumes the ``model_fn`` fixture's default return value is not one of the
    supported model types (e.g. a bare mock) -- confirm against conftest.
    """
    user_module = None
    with pytest.raises(ValueError) as e:
        _user_module_transformer(user_module, MODEL_DIR)
    # Check the raised exception itself rather than str() of the ExceptionInfo
    # wrapper. That wrapper's string form differs across pytest versions and
    # can truncate the message.
    assert 'Unsupported model type' in str(e.value)