예제 #1
0
def test_default_input_fn_csv_bad_columns():
    """A CSV payload whose rows differ in length must be rejected with ValueError."""
    ragged_rows = ([1, 2, 3], [1, 2, 3, 4])
    buffer = StringIO()
    csv.writer(buffer, delimiter=',').writerows(ragged_rows)
    payload = buffer.getvalue()

    with pytest.raises(ValueError):
        default_input_fn(payload, content_types.CSV)
def test_default_input_fn_npy(tensor):
    """Round-trip ``tensor`` through NPY bytes and check device placement and values.

    ``tensor`` is presumably a pytest fixture defined elsewhere (TODO confirm).
    """
    npy_buffer = BytesIO()
    # Move to host memory first: .numpy() only works on CPU tensors.
    host_array = tensor.cpu().numpy()
    np.save(npy_buffer, host_array)
    result = default_input_fn(npy_buffer.getvalue(), content_types.NPY)

    # The decoded tensor is expected on GPU exactly when CUDA is available.
    assert result.is_cuda == torch.cuda.is_available()
    assert torch.equal(tensor, result)
def test_default_input_fn_csv():
    """Decode a well-formed CSV payload and compare it with the expected float32 tensor.

    ``device`` is defined elsewhere in this module (presumably CUDA when
    available, else CPU — TODO confirm).
    """
    rows = [[1, 2, 3], [4, 5, 6]]
    str_io = StringIO()
    csv.writer(str_io, delimiter=',').writerows(rows)

    deserialized_tensor = default_input_fn(str_io.getvalue(), content_types.CSV)

    # Check device placement first, as the NPY/JSON tests do. torch.equal on
    # tensors living on different devices raises an error rather than giving
    # a clean assertion failure, which would hide the real problem.
    assert deserialized_tensor.is_cuda == torch.cuda.is_available()

    # Build the expected tensor directly on the target device. float32 matches
    # the dtype the legacy torch.FloatTensor constructor used to produce.
    expected = torch.tensor(rows, dtype=torch.float32, device=device)
    assert torch.equal(expected, deserialized_tensor)
def test_default_input_fn_bad_content_type():
    """An unrecognized content type must raise encoders.UnsupportedFormatError."""
    unsupported_type = 'application/not_supported'
    empty_payload = ''
    with pytest.raises(encoders.UnsupportedFormatError):
        default_input_fn(empty_payload, unsupported_type)
def test_default_input_fn_json(tensor):
    """Round-trip ``tensor`` through a JSON nested list and check device and values.

    ``tensor`` is presumably a pytest fixture defined elsewhere (TODO confirm).
    """
    nested_lists = tensor.cpu().numpy().tolist()
    payload = json.dumps(nested_lists)
    result = default_input_fn(payload, content_types.JSON)

    # The decoded tensor is expected on GPU exactly when CUDA is available.
    assert result.is_cuda == torch.cuda.is_available()
    assert torch.equal(tensor, result)