Exemplo n.º 1
0
def test_export_stream():
    """Export a small model into an in-memory stream and check the result.

    Verifies that ``export`` writes serialized bytes into a ``BytesIO``
    target and, with ``return_output=True``, returns the model output.
    """
    model = nn.Sequential(nn.Linear(5, 10, bias=False))
    x = torch.zeros((2, 5))

    bytesio = io.BytesIO()
    # Precondition: the stream starts empty, so any bytes found afterwards
    # must have been written by export().
    assert len(bytesio.getvalue()) == 0
    out = export(model, x, bytesio, return_output=True)

    assert len(bytesio.getvalue()) > 0
    # Linear(5, 10) maps a (2, 5) batch to (2, 10); checked explicitly so a
    # shape mismatch produces a clear failure message.
    assert tuple(out.shape) == (2, 10)
    # The input is all zeros and the layer has no bias, so the output is
    # exactly zero regardless of the randomly initialized weights. This
    # therefore checks values as well as shape.
    expected_out = torch.zeros((2, 10))
    np.testing.assert_allclose(out.detach().cpu().numpy(),
                               expected_out.numpy())
Exemplo n.º 2
0
def test_export_filename():
    """Export a small model to a file path and check the result.

    Verifies that ``export`` emits a ``UserWarning``, creates a non-empty
    file at ``model_path``, and returns the model output when
    ``return_output=True``.
    """
    model = nn.Sequential(nn.Linear(5, 10, bias=False))
    x = torch.zeros((2, 5))

    output_dir = _get_output_dir('export_filename')
    model_path = os.path.join(output_dir, 'model.onnx')
    # The output directory may persist between runs (TODO confirm how
    # _get_output_dir behaves). Remove any stale file so the isfile check
    # below proves that this export() call actually wrote it.
    if os.path.isfile(model_path):
        os.remove(model_path)

    # This call is expected to emit a UserWarning; pytest.warns fails the
    # test if none is raised.
    with pytest.warns(UserWarning):
        out = export(model, x, model_path, return_output=True)

    assert os.path.isfile(model_path)
    # Mirror the stream test: the export must contain some bytes.
    assert os.path.getsize(model_path) > 0
    # Linear(5, 10) maps a (2, 5) batch to (2, 10).
    assert tuple(out.shape) == (2, 10)
    # Zero input and no bias give an exactly-zero output whatever the
    # weights are, so this checks values as well as shape.
    expected_out = torch.zeros((2, 10))
    np.testing.assert_allclose(out.detach().cpu().numpy(),
                               expected_out.numpy())