예제 #1
0
def test_user_module_transformer_with_transform_fn(model_fn):
    """A user module that defines ``transform_fn`` has it stored on the transformer.

    ``model_fn`` is requested as a fixture but not used directly here
    (presumably it patches the default model loader — confirm in conftest).
    """

    class FakeUserModule:
        def __init__(self):
            # The only attribute this fake module exposes.
            self.transform_fn = Mock()

    fake_module = FakeUserModule()

    transformer = _user_module_transformer(fake_module, MODEL_DIR)

    assert transformer._transform_fn == fake_module.transform_fn
예제 #2
0
def test_user_module_transformer_module_transformer_no_user_methods(model_fn):
    """With no user module, a Module returned by ``model_fn`` yields a ModuleTransformer.

    Every stage hook should fall back to the transformer's own defaults.
    """
    base_module = mx.module.BaseModule()
    # The model_fn fixture is a mock; make it hand back an MXNet Module.
    model_fn.return_value = base_module

    transformer = _user_module_transformer(None, MODEL_DIR)

    assert isinstance(transformer, ModuleTransformer)
    assert transformer._model == base_module
    assert transformer._model_fn == model_fn
    # _input_fn / _predict_fn / _output_fn must each be the matching default_*_fn.
    for stage in ('input', 'predict', 'output'):
        actual = getattr(transformer, '_%s_fn' % stage)
        expected = getattr(transformer, 'default_%s_fn' % stage)
        assert actual == expected
예제 #3
0
def test_user_module_transformer_gluon_transformer_with_user_methods():
    """A user module whose ``model_fn`` returns a Gluon Block yields a GluonBlockTransformer.

    All four user-supplied hooks should be stored on the transformer unchanged.
    """
    block = mx.gluon.block.Block()

    class FakeUserModule:
        def __init__(self):
            self.input_fn = Mock()
            self.predict_fn = Mock()
            self.output_fn = Mock()

        def model_fn(self, model_dir):
            # The directory is ignored; always return the prebuilt block.
            return block

    fake_module = FakeUserModule()

    transformer = _user_module_transformer(fake_module, MODEL_DIR)

    assert isinstance(transformer, GluonBlockTransformer)
    assert transformer._model == block
    # Each hook lives on the transformer under the same name with a leading underscore.
    for hook in ('model_fn', 'input_fn', 'predict_fn', 'output_fn'):
        assert getattr(transformer, '_' + hook) == getattr(fake_module, hook)
예제 #4
0
def test_user_module_transformer_unsupported_model_type(model_fn):
    """A model that is neither a supported Module nor a Block raises ValueError.

    ``model_fn`` is the mocked loader fixture. Its default return value is
    presumably a plain Mock, which is not a supported model type (confirm in conftest).
    """
    with pytest.raises(ValueError) as excinfo:
        _user_module_transformer(None, MODEL_DIR)

    # Check the raised exception itself. Since pytest 5, str() of the
    # ExceptionInfo wrapper is just its repr, so matching against it is fragile.
    assert 'Unsupported model type' in str(excinfo.value)