def test_from_module_complete(self, mxc, generic_module):
    """from_module builds a plain MXNetTransformer from a fully-specified user module.

    ``generic_module`` is requested as a pytest fixture, the same way the other
    tests in this file receive it. The previous version called
    ``generic_module()`` directly, which fails when ``generic_module`` is a
    fixture function (pytest >= 4 forbids calling fixtures directly).
    """
    from mxnet_container.serve.transformer import MXNetTransformer

    t = MXNetTransformer.from_module(generic_module)

    assert isinstance(t, MXNetTransformer)
    assert t.model == generic_model
    assert t.transform_fn == generic_transform_fn
    # NOTE(review): the expected value assumes generic_transform_fn returns the
    # data together with the accept type -- confirm against its definition.
    assert t.transform('x', 'content-type', 'accept') == ('x', 'accept')
def test_from_module(self, select, mxc, module_module):
    """A Module-based user module is wrapped in a ModuleTransformer.

    The ``select`` mock is primed to return ModuleTransformer. The resulting
    transformer must hold ``module_module._module`` as its model. Its
    transform of a JSON request must yield the value expected from the
    fixture's handlers.
    """
    from mxnet_container.serve.transformer import MXNetTransformer, ModuleTransformer

    select.return_value = ModuleTransformer
    transformer = MXNetTransformer.from_module(module_module)

    assert isinstance(transformer, ModuleTransformer)
    assert transformer.model == module_module._module

    # presumably each fixture handler wraps its argument in its own name,
    # giving the nested string below -- confirm in the module_module fixture
    actual = transformer.transform('x', JSON_CONTENT_TYPE, JSON_CONTENT_TYPE)
    expected = ('output(predict(input(x)))', JSON_CONTENT_TYPE)
    assert actual == expected
def test_from_module_with_default_model_fn(self, model_fn, mxc, generic_module):
    """Without a user model_fn, from_module falls back to the default model_fn.

    The user module keeps its transform_fn. The resulting MXNetTransformer
    should therefore combine that transform_fn with the model returned by the
    mocked default model_fn.
    """
    from mxnet_container.serve.transformer import MXNetTransformer

    # Remove the user-supplied model_fn so the default one has to be used.
    del generic_module.model_fn
    model_fn.return_value = generic_model

    transformer = MXNetTransformer.from_module(generic_module)

    assert isinstance(transformer, MXNetTransformer)
    assert transformer.model == generic_model
    assert transformer.transform_fn == generic_transform_fn
def test_from_module_with_defaults(self, input_fn, predict_fn, output_fn,
                                   select, mxc, gluon_module):
    """A Gluon user module without handlers is served by the default handlers.

    The ``select`` mock picks GluonBlockTransformer. The module's
    input/predict/output handlers are removed, so the mocked defaults must be
    chained: input -> predict (with the Gluon block) -> output.
    """
    from mxnet_container.serve.transformer import MXNetTransformer, GluonBlockTransformer

    select.return_value = GluonBlockTransformer

    # Strip the user handlers so that the default (mocked) handlers are used.
    for handler_name in ('input_fn', 'predict_fn', 'output_fn'):
        delattr(gluon_module, handler_name)

    input_fn.return_value = 'default_input'
    predict_fn.return_value = 'default_predict'
    output_fn.return_value = ('default_output', 'accept')

    transformer = MXNetTransformer.from_module(gluon_module)
    block = gluon_module._block

    assert isinstance(transformer, GluonBlockTransformer)
    assert transformer.model == block

    result = transformer.transform('x', JSON_CONTENT_TYPE, JSON_CONTENT_TYPE)
    assert result == ('default_output', 'accept')

    # Each default handler must receive the previous stage's output.
    input_fn.assert_called_with('x', JSON_CONTENT_TYPE)
    predict_fn.assert_called_with(block, 'default_input')
    output_fn.assert_called_with('default_predict', JSON_CONTENT_TYPE)