def test_call_torch():
    """Check call_torch with a unary (torch.sqrt) and a binary (torch.add) function.

    Inputs are NumPy float arrays; results are compared elementwise against
    plain Python lists.
    """
    roots = call_torch(torch.sqrt, np.asarray([1, 4, 9], dtype="float"))
    np.testing.assert_almost_equal(roots, [1, 2, 3])

    lhs = np.asarray([1, 2, 3], dtype="float")
    rhs = np.asarray([4, 5, 6], dtype="float")
    sums = call_torch(torch.add, lhs, rhs)
    np.testing.assert_almost_equal(sums, [5, 7, 9])
def test_call_torch_structured():
    """Check call_torch with a tuple-structured argument and a tuple-structured result.

    The wrapped function indexes into the tuple it receives and returns a
    pair, which is unpacked and compared against the expected values.
    """
    pair = (
        np.asarray([1, 2, 3], dtype="float"),
        np.asarray([4, 5, 6], dtype="float"),
    )

    def sum_and_diff(t):
        return t[0] + t[1], t[1] - t[0]

    total, diff = call_torch(sum_and_diff, pair)
    np.testing.assert_almost_equal(total, [5, 7, 9])
    np.testing.assert_almost_equal(diff, [3, 3, 3])
def test_call_torch_batched():
    """Check call_torch with batch_size against NumPy's elementwise square root.

    Covers an input length that is an exact multiple of batch_size as well as
    one that is not. The second case makes sure the trailing partial batch is
    exercised, not just full batches.
    """
    batch_size = 128
    # 1024 -> 8 full batches; 1000 -> 7 full batches plus a partial batch of 104.
    for length in (1024, 1000):
        values = np.arange(length).astype("float")
        np.testing.assert_almost_equal(
            call_torch(torch.sqrt, values, batch_size=batch_size),
            np.arange(length) ** 0.5,
        )