def test_mxnet_transformer_default_output_fn_invalid_content_type():
    """default_output_fn raises UnsupportedFormatError for an unknown accept type.

    The error message must name the rejected content type.
    """
    t = MXNetTransformer()
    with pytest.raises(errors.UnsupportedFormatError) as e:
        t.default_output_fn(None, 'bad/content-type')
    # Check the raised exception itself (e.value). str(ExceptionInfo) is a
    # repr-style summary in pytest >= 5.0, not the exception message.
    assert 'Content type bad/content-type is not supported by this framework' in str(
        e.value)
def test_mxnet_transformer_default_input_fn_with_accelerator(decode, mx_ndarray, mx_eia):
    """default_input_fn moves the built NDArray into the mx.cpu() context.

    NOTE(review): decode, mx_ndarray and mx_eia are presumably injected by
    patch decorators not visible here. mx_eia is not referenced in the body;
    it is assumed to simulate an attached accelerator. Confirm against the
    decorators.
    """
    # Whatever mx_ndarray produces is a plain Mock we can inspect afterwards.
    converted = Mock()
    mx_ndarray.return_value = converted

    MXNetTransformer().default_input_fn(Mock(), 'application/json')

    converted.as_in_context.assert_called_with(mx.cpu())
def test_mxnet_transformer_default_input_fn_with_npy(decode):
    """default_input_fn decodes NPY input and returns it as an MXNet NDArray.

    NOTE(review): `decode` is presumably a patched decoder whose return value
    is configured outside this view. The final comparison expects it to yield
    data equal to mx.nd.array([0]).
    """
    npy_type = 'application/x-npy'
    payload = Mock()

    result = MXNetTransformer().default_input_fn(payload, npy_type)

    # The raw payload and its content type are forwarded unchanged to the decoder.
    decode.assert_called_with(payload, npy_type)
    assert result == mx.nd.array([0])
def test_mxnet_transformer_default_output_fn(encode):
    """default_output_fn encodes the prediction as a list and wraps it in a worker.Response.

    NOTE(review): `encode` is presumably a patched encoder injected by a
    decorator not visible here.
    """
    accept = 'application/json'
    prediction = mx.ndarray.zeros(1)

    response = MXNetTransformer().default_output_fn(prediction, accept)

    # The encoder receives the NDArray converted to a nested Python list.
    encode.assert_called_with(prediction.asnumpy().tolist(), accept)
    assert isinstance(response, worker.Response)
def test_mxnet_transformer_init():
    """A default-constructed MXNetTransformer has no model and uses the default hooks."""
    transformer_under_test = MXNetTransformer()

    assert transformer_under_test._model is None
    assert transformer_under_test._model_fn == transformer.default_model_fn

    # Bound methods compare equal (==) but are distinct objects, so `is` would fail here.
    hook_pairs = (
        (transformer_under_test._input_fn, transformer_under_test.default_input_fn),
        (transformer_under_test._predict_fn, transformer_under_test.default_predict_fn),
        (transformer_under_test._output_fn, transformer_under_test.default_output_fn),
    )
    for actual, expected in hook_pairs:
        assert actual == expected

    assert transformer_under_test.VALID_CONTENT_TYPES == (content_types.JSON, content_types.NPY)
def test_mxnet_transformer_init_with_args():
    """Every constructor keyword argument is stored on the matching private attribute."""
    # Each keyword maps to a private attribute of the same name with a leading underscore.
    overrides = {
        name: Mock()
        for name in ('model', 'model_fn', 'input_fn', 'predict_fn', 'output_fn', 'error_class')
    }

    t = MXNetTransformer(**overrides)

    for name, supplied in overrides.items():
        assert getattr(t, '_' + name) == supplied
def test_mxnet_transformer_initialize_with_model(transformer_initialize):
    """initialize() does not call transformer_initialize when a model was supplied up front.

    NOTE(review): `transformer_initialize` is presumably a patch of the base
    class initializer, injected by a decorator not visible here.
    """
    preloaded = MXNetTransformer(model=Mock())
    preloaded.initialize()
    transformer_initialize.assert_not_called()
def test_mxnet_transformer_initialize_without_model(transformer_initialize):
    """initialize() calls transformer_initialize exactly once when no model was supplied.

    NOTE(review): `transformer_initialize` is presumably a patch of the base
    class initializer, injected by a decorator not visible here.
    """
    empty = MXNetTransformer()
    empty.initialize()
    transformer_initialize.assert_called_once()