コード例 #1
0
def test_input_fn_npz(np_array):
    input_data = encoders.array_to_npy(np_array)
    deserialized_np_array = serving.default_input_fn(input_data,
                                                     content_types.NPY)

    assert np.array_equal(np_array, deserialized_np_array)

    float_32_array = np.array(np_array, dtype=np.float32)
    input_data = encoders.array_to_npy(float_32_array)
    deserialized_np_array = serving.default_input_fn(input_data,
                                                     content_types.NPY)

    assert np.array_equal(float_32_array, deserialized_np_array)

    float_64_array = np.array(np_array, dtype=np.float64)
    input_data = encoders.array_to_npy(float_64_array)
    deserialized_np_array = serving.default_input_fn(input_data,
                                                     content_types.NPY)

    assert np.array_equal(float_64_array, deserialized_np_array)
コード例 #2
0
def test_input_fn_bad_content_type():
    with pytest.raises(errors.UnsupportedFormatError):
        serving.default_input_fn('', 'application/not_supported')
コード例 #3
0
def test_input_fn_csv(csv_data, expected):
    deserialized_np_array = serving.default_input_fn(csv_data,
                                                     content_types.CSV)
    assert np.array_equal(expected, deserialized_np_array)
コード例 #4
0
def test_input_fn_json(json_data, expected):
    actual = serving.default_input_fn(json_data, content_types.JSON)
    np.testing.assert_equal(actual, expected)