예제 #1
0
def test_mxnet_transformer_default_output_fn_invalid_content_type():
    """Check that ``default_output_fn`` rejects an unsupported content type.

    The call is expected to raise ``errors.UnsupportedFormatError`` with a
    message that names the offending content type.
    """
    t = MXNetTransformer()

    with pytest.raises(errors.UnsupportedFormatError) as e:
        t.default_output_fn(None, 'bad/content-type')

    # ``e`` is pytest's ExceptionInfo wrapper, not the exception. Since
    # pytest 5 its str() is no longer the exception message, so the raised
    # exception itself (``e.value``) is what must be inspected.
    assert 'Content type bad/content-type is not supported by this framework' in str(
        e.value)
def test_mxnet_transformer_default_input_fn_with_accelerator(decode, mx_ndarray, mx_eia):
    """Check that ``default_input_fn`` moves its array to the CPU context.

    The object returned by ``mx_ndarray`` must have ``as_in_context`` called
    with ``mx.cpu()``.

    NOTE(review): ``decode``, ``mx_ndarray`` and ``mx_eia`` are presumably
    mocks injected by patch decorators not visible here - confirm.
    """
    converted_array = Mock()
    mx_ndarray.return_value = converted_array

    MXNetTransformer().default_input_fn(Mock(), 'application/json')

    converted_array.as_in_context.assert_called_with(mx.cpu())
예제 #3
0
def test_mxnet_transformer_default_input_fn_with_npy(decode):
    """Check ``default_input_fn`` for the ``application/x-npy`` content type.

    The payload and content type must be forwarded to ``decode`` unchanged,
    and the result must compare equal to ``mx.nd.array([0])``.

    NOTE(review): ``decode`` is presumably a patched mock whose return value
    is configured outside this view - confirm.
    """
    npy_content_type = 'application/x-npy'
    payload = Mock()

    result = MXNetTransformer().default_input_fn(payload, npy_content_type)

    decode.assert_called_with(payload, npy_content_type)
    assert result == mx.nd.array([0])
예제 #4
0
def test_mxnet_transformer_default_output_fn(encode):
    """Check ``default_output_fn`` for a JSON ``accept`` type.

    The prediction must reach ``encode`` as a plain Python list (via
    ``asnumpy().tolist()``) together with the accept type, and the call must
    return a ``worker.Response``.

    NOTE(review): ``encode`` is presumably a patched mock - confirm.
    """
    accept = 'application/json'
    prediction = mx.ndarray.zeros(1)

    response = MXNetTransformer().default_output_fn(prediction, accept)

    # The encoder is expected to receive the list form of the NDArray.
    encode.assert_called_with(prediction.asnumpy().tolist(), accept)

    assert isinstance(response, worker.Response)
예제 #5
0
def test_mxnet_transformer_init():
    """Check the state of an ``MXNetTransformer`` built with no arguments.

    It must hold no model, use ``transformer.default_model_fn`` plus its own
    default input/predict/output functions, and accept the JSON and NPY
    content types.
    """
    mxnet_transformer = MXNetTransformer()

    assert mxnet_transformer._model is None
    assert mxnet_transformer._model_fn == transformer.default_model_fn
    assert mxnet_transformer._input_fn == mxnet_transformer.default_input_fn
    assert mxnet_transformer._predict_fn == mxnet_transformer.default_predict_fn
    assert mxnet_transformer._output_fn == mxnet_transformer.default_output_fn

    supported_types = (content_types.JSON, content_types.NPY)
    assert mxnet_transformer.VALID_CONTENT_TYPES == supported_types
def test_mxnet_transformer_init_with_args():
    """Check that constructor keyword arguments are stored on the instance.

    Each keyword passed to ``MXNetTransformer`` must be readable back from
    the private attribute of the same name (``model`` -> ``_model``, ...).
    """
    argument_names = ('model', 'model_fn', 'input_fn', 'predict_fn',
                      'output_fn', 'error_class')
    supplied = {name: Mock() for name in argument_names}

    t = MXNetTransformer(**supplied)

    for name in argument_names:
        assert getattr(t, '_' + name) == supplied[name]
예제 #7
0
def test_mxnet_transformer_initialize_with_model(transformer_initialize):
    """Check that ``initialize`` does not call ``transformer_initialize``
    when a model was supplied to the constructor.

    NOTE(review): ``transformer_initialize`` is presumably a patched mock of
    the base-class ``initialize`` - confirm.
    """
    preloaded = MXNetTransformer(model=Mock())

    preloaded.initialize()

    transformer_initialize.assert_not_called()
예제 #8
0
def test_mxnet_transformer_initialize_without_model(transformer_initialize):
    """Check that ``initialize`` calls ``transformer_initialize`` exactly
    once when no model was supplied to the constructor.

    NOTE(review): ``transformer_initialize`` is presumably a patched mock of
    the base-class ``initialize`` - confirm.
    """
    MXNetTransformer().initialize()

    transformer_initialize.assert_called_once()