def test_default_output_fn_gpu():
    """A LongTensor on a CUDA device serializes to the same CSV text as on CPU.

    The test is skipped at runtime when no CUDA device is present. Otherwise
    ``.cuda()`` would raise on CPU-only hosts and report an error rather than
    a skip.
    """
    # Guard inside the body so the check applies even without a skipif marker.
    if not torch.cuda.is_available():
        pytest.skip('cuda is not available')

    tensor_gpu = torch.LongTensor([[1, 2, 3], [4, 5, 6]]).cuda()

    output = default_output_fn(tensor_gpu, content_types.CSV)

    assert '1,2,3\n4,5,6\n' in output.get_data(as_text=True)
    assert content_types.CSV in output.content_type
def test_default_output_fn_npy(tensor):
    """NPY output contains exactly the bytes ``np.save`` writes for the tensor's array."""
    response = default_output_fn(tensor, content_types.NPY)

    # Build the reference payload the same way numpy would serialize it.
    reference_buffer = BytesIO()
    host_array = tensor.cpu().numpy()
    np.save(reference_buffer, host_array)
    expected_bytes = reference_buffer.getvalue()

    raw_body = response.get_data(as_text=False)
    assert expected_bytes in raw_body
    assert content_types.NPY in response.content_type
def test_default_output_fn_bad_accept():
    """An accept type with no matching encoder raises ``UnsupportedFormatError``."""
    unsupported_accept = 'application/not_supported'
    empty_prediction = ''
    with pytest.raises(encoders.UnsupportedFormatError):
        default_output_fn(empty_prediction, unsupported_accept)
def test_default_output_fn_csv_float():
    """A FloatTensor serializes to CSV with float-formatted values, one row per line."""
    float_tensor = torch.FloatTensor([[1, 2, 3], [4, 5, 6]])
    response = default_output_fn(float_tensor, content_types.CSV)

    csv_text = response.get_data(as_text=True)
    assert '1.0,2.0,3.0\n4.0,5.0,6.0\n' in csv_text
    assert content_types.CSV in response.content_type
def test_default_output_fn_json(tensor):
    """JSON output contains the ``json.dumps`` encoding of the tensor as nested lists."""
    response = default_output_fn(tensor, content_types.JSON)

    nested_values = tensor.cpu().numpy().tolist()
    expected_json = json.dumps(nested_values)
    assert expected_json in response.get_data(as_text=True)
    assert content_types.JSON in response.content_type
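# Example No. 6 (scraped listing header, originally "Ejemplo n.º 6")
# 0  -- presumably the listing's vote count; scraper residue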
def test_default_output_fn_csv_long():
    """A LongTensor serializes to integer CSV and reports the CSV mimetype exactly."""
    long_tensor = torch.LongTensor([[1, 2, 3], [4, 5, 6]])
    response = default_output_fn(long_tensor, content_types.CSV)

    csv_text = response.get_data(as_text=True)
    assert '1,2,3\n4,5,6\n' in csv_text
    assert content_types.CSV == response.mimetype