예제 #1
0
def test_mxnet_wrapper_thinc_model_subclass(mx_model):
    """MXNetWrapper builds an instance of the caller-supplied ``model_class``.

    A locally defined ``Model`` subclass with an extra method is passed in;
    the returned object must be of that subclass and expose the method.
    """
    sentinel = 1337

    class CustomModel(Model):
        def fn(self) -> int:
            return sentinel

    wrapped = MXNetWrapper(mx_model, model_class=CustomModel)
    assert isinstance(wrapped, CustomModel)
    assert wrapped.fn() == sentinel
예제 #2
0
def test_mxnet_wrapper_gluon_sequential():
    """Wrapping a Gluon ``Sequential`` network yields a thinc ``Model``.

    The network is not initialized here; only the wrapper's type is checked.
    """
    import mxnet as mx

    nn = mx.gluon.nn
    net = nn.Sequential()
    net.add(nn.Dense(12))
    assert isinstance(MXNetWrapper(net), Model)
예제 #3
0
def test_mxnet_wrapper_convert_inputs(data, n_args, kwargs_keys):
    """The wrapper's ``convert_inputs`` hook should produce MXNet NDArrays.

    ``data``, ``n_args`` and ``kwargs_keys`` are presumably provided by a
    pytest parametrization outside this view (TODO confirm). They are
    forwarded unchanged to ``check_input_converters`` along with the
    expected array type ``mx.nd.NDArray``.
    """
    import mxnet as mx

    gluon_nn = mx.gluon.nn
    net = gluon_nn.Sequential()
    net.add(gluon_nn.Dense(12))
    net.initialize()
    wrapped = MXNetWrapper(net)
    converter = wrapped.attrs["convert_inputs"]
    converted, backprop = converter(wrapped, data, is_train=True)
    check_input_converters(
        converted, backprop, data, n_args, kwargs_keys, mx.nd.NDArray
    )
예제 #4
0
def model(mx_model) -> Model[Array2d, Array2d]:
    """Return a thinc ``Model`` that wraps the ``mx_model`` Gluon network.

    Presumably registered as a pytest fixture by a decorator outside this
    view (TODO confirm).
    """
    wrapped: Model[Array2d, Array2d] = MXNetWrapper(mx_model)
    return wrapped
예제 #5
0
def test_mxnet_wrapper_thinc_set_model_name(mx_model):
    """A ``model_name`` passed to MXNetWrapper becomes the Model's ``name``."""
    expected = "cool"
    assert MXNetWrapper(mx_model, model_name=expected).name == expected
예제 #6
0
def test_mxnet_wrapper_to_cpu(mx_model, X: Array2d):
    """Moving a wrapped model to CPU must leave it usable for prediction.

    The model predicts once, is moved with ``to_cpu()``, then predicts
    again. The second output must have the same shape as the first.
    """
    model = MXNetWrapper(mx_model)
    Y_before = model.predict(X)
    model.to_cpu()
    # Without a post-move prediction, a device move that leaves parameters
    # in an unusable context would go unnoticed; the test would only check
    # that to_cpu() itself does not raise.
    Y_after = model.predict(X)
    assert Y_after.shape == Y_before.shape