def test_mxnet_wrapper_thinc_model_subclass(mx_model):
    """Passing ``model_class`` makes the wrapper return an instance of that subclass."""

    class CustomModel(Model):
        def fn(self) -> int:
            return 1337

    wrapped = MXNetWrapper(mx_model, model_class=CustomModel)
    # Both the type and the subclass-only method must be available.
    assert isinstance(wrapped, CustomModel)
    assert wrapped.fn() == 1337
def test_mxnet_wrapper_gluon_sequential():
    """A Gluon ``Sequential`` container holding one ``Dense`` layer wraps into a ``Model``."""
    import mxnet as mx

    gluon_nn = mx.gluon.nn
    net = gluon_nn.Sequential()
    net.add(gluon_nn.Dense(12))
    assert isinstance(MXNetWrapper(net), Model)
def test_mxnet_wrapper_convert_inputs(data, n_args, kwargs_keys):
    """Exercise the wrapper's ``convert_inputs`` hook on *data*.

    The converted output and its backprop callback are handed to
    ``check_input_converters`` along with ``mx.nd.NDArray``. ``mx.nd.NDArray``
    is presumably the expected array type; confirm against that helper.
    ``data``, ``n_args`` and ``kwargs_keys`` are assumed to come from a
    parametrize decorator outside this view.
    """
    import mxnet as mx

    net = mx.gluon.nn.Sequential()
    net.add(mx.gluon.nn.Dense(12))
    net.initialize()

    wrapper = MXNetWrapper(net)
    converter = wrapper.attrs["convert_inputs"]
    converted, backprop = converter(wrapper, data, is_train=True)
    check_input_converters(
        converted, backprop, data, n_args, kwargs_keys, mx.nd.NDArray
    )
def model(mx_model) -> Model[Array2d, Array2d]:
    """Return *mx_model* wrapped in an ``MXNetWrapper``.

    NOTE(review): presumably registered as a pytest fixture by a decorator not
    visible here.
    """
    wrapped = MXNetWrapper(mx_model)
    return wrapped
def test_mxnet_wrapper_thinc_set_model_name(mx_model):
    """The ``model_name`` keyword becomes the wrapped model's ``name``."""
    expected = "cool"
    wrapped = MXNetWrapper(mx_model, model_name=expected)
    assert wrapped.name == expected
def test_mxnet_wrapper_to_cpu(mx_model, X: Array2d):
    """Smoke test: ``to_cpu()`` runs without raising after a ``predict`` call."""
    wrapped = MXNetWrapper(mx_model)
    # Presumably a forward pass is run first so the wrapped network is
    # exercised before the device move.
    wrapped.predict(X)
    wrapped.to_cpu()