def test_mxnet_transformer_default_input_fn_invalid_content_type():
    """Check that default_input_fn rejects an unsupported content type.

    Calling ``default_input_fn`` with ``'bad/content-type'`` must raise
    ``errors.UnsupportedFormatError`` and the exception message must name
    the offending content type.
    """
    t = MXNetTransformer()

    with pytest.raises(errors.UnsupportedFormatError) as e:
        t.default_input_fn(None, 'bad/content-type')

    # Inspect the raised exception itself (e.value): the string form of the
    # ExceptionInfo wrapper differs between pytest versions and is not the
    # exception message.
    assert 'Content type bad/content-type is not supported by this framework' in str(
        e.value)
def test_mxnet_transformer_default_input_fn_with_accelerator(decode, mx_ndarray, mx_eia):
    """Check the context the deserialized array is moved to.

    ``mx_ndarray`` is configured to return a mock array; after calling
    ``default_input_fn`` with a JSON content type, that mock array must have
    had ``as_in_context`` called with the value of ``mx.cpu()``.

    The ``decode``, ``mx_ndarray`` and ``mx_eia`` arguments are injected from
    outside this function (presumably mock.patch decorators or fixtures that
    are not visible here -- confirm).
    """
    # NOTE(review): the test name mentions an accelerator while the expected
    # context is mx.cpu(); confirm which target ``mx_eia`` replaces.
    fake_array = Mock()
    mx_ndarray.return_value = fake_array

    transformer = MXNetTransformer()
    request_body = Mock()
    transformer.default_input_fn(request_body, 'application/json')

    expected_context = mx.cpu()
    fake_array.as_in_context.assert_called_with(expected_context)
def test_mxnet_transformer_default_input_fn_with_npy(decode):
    """Check default_input_fn handling of the ``application/x-npy`` type.

    The payload and content type must be forwarded unchanged to ``decode``,
    and the value returned by ``default_input_fn`` must compare equal to
    ``mx.nd.array([0])``.

    ``decode`` is injected from outside this function (presumably a patched
    decoder configured to return ``[0]`` -- confirm against the decorator or
    fixture, which is not visible here).
    """
    payload = Mock()
    npy_type = 'application/x-npy'

    transformer = MXNetTransformer()
    result = transformer.default_input_fn(payload, npy_type)

    # The decoder must have been handed exactly what the caller supplied.
    decode.assert_called_with(payload, npy_type)

    assert result == mx.nd.array([0])