def test_input_fn_npz(np_array):
    """Round-trip arrays through NPY encoding and ``serving.default_input_fn``.

    The ``np_array`` argument and copies of it cast to float32 and float64 are
    each encoded with ``encoders.array_to_npy``. Each one is then decoded with
    ``serving.default_input_fn(..., content_types.NPY)`` and must come back
    equal in both values and dtype.

    NOTE(review): despite the ``npz`` in the name, this exercises the NPY
    content type (``content_types.NPY``), not NPZ. The name is kept so pytest
    selection and CI references keep working.
    """
    candidates = [
        np_array,
        np.array(np_array, dtype=np.float32),
        np.array(np_array, dtype=np.float64),
    ]
    for original in candidates:
        # np.asarray lets this work whether or not the fixture is already an ndarray.
        expected_dtype = np.asarray(original).dtype
        input_data = encoders.array_to_npy(original)
        deserialized_np_array = serving.default_input_fn(input_data, content_types.NPY)
        assert np.array_equal(original, deserialized_np_array), expected_dtype
        # np.array_equal ignores dtype, so check it explicitly; otherwise the
        # float32/float64 cases would pass even if the dtype were not preserved.
        assert deserialized_np_array.dtype == expected_dtype, (
            deserialized_np_array.dtype,
            expected_dtype,
        )
def test_input_fn_bad_content_type():
    """An unsupported content type makes ``default_input_fn`` raise.

    ``serving.default_input_fn`` is called with an empty payload and a content
    type it does not handle. The call must raise
    ``errors.UnsupportedFormatError``.
    """
    unsupported_type = 'application/not_supported'
    with pytest.raises(errors.UnsupportedFormatError):
        serving.default_input_fn('', unsupported_type)
def test_input_fn_csv(csv_data, expected):
    """A CSV payload decoded by ``default_input_fn`` matches ``expected``.

    ``csv_data`` and ``expected`` are presumably supplied by a pytest fixture
    or parametrization outside this view.

    The comparison uses ``np.array_equal``: shapes and values must match, but
    dtype is not compared.
    """
    decoded = serving.default_input_fn(csv_data, content_types.CSV)
    assert np.array_equal(expected, decoded)
def test_input_fn_json(json_data, expected):
    """A JSON payload decoded by ``default_input_fn`` equals ``expected``.

    ``json_data`` and ``expected`` presumably come from a pytest fixture or
    parametrization outside this view.

    The comparison is delegated to ``np.testing.assert_equal``, which raises
    ``AssertionError`` on mismatch.
    """
    np.testing.assert_equal(
        serving.default_input_fn(json_data, content_types.JSON),
        expected,
    )