def test_addition():
    """Adding operands of mismatched shapes must yield an augmented TF error.

    The expected message is TensorFlow's own InvalidArgumentError text followed
    by a "Cause:" line from tsensor naming the ``+`` operator and the shapes of
    operands ``b`` (2, 1) and ``x`` (3, 1).
    """
    expected = (
        "Incompatible shapes: [2,1] vs. [3,1] [Op:AddV2]\n"
        "Cause: + on tensor operand b w/shape (2, 1) and operand x w/shape (3, 1)"
    )
    captured = ""
    try:
        with tsensor.clarify():
            # NOTE: keep this statement verbatim -- tsensor re-parses its source
            # text to build the "Cause:" line.
            q = b + x + 3
    except tf.errors.InvalidArgumentError as err:
        captured = err.message
    assert captured == expected
def test_fft():
    """Test a library function that doesn't have a shape related message in the exception.

    ``jnp.fft.fft`` rejects the bogus ``norm`` argument; tsensor.clarify() is
    expected to append a "Cause:" line describing the call and the shape of
    its tensor argument ``x`` (8,).
    """
    x = np.exp(2j * np.pi * np.arange(8) / 8)
    msg = ""
    try:
        with tsensor.clarify():
            # NOTE: keep this statement verbatim -- tsensor re-parses its source
            # text to build the "Cause:" line.
            y = jnp.fft.fft(x, norm="something weird")
    # Catch Exception rather than BaseException so a KeyboardInterrupt or
    # SystemExit raised during the test is not swallowed and misreported as
    # an assertion failure.
    except Exception as e:
        msg = e.args[0]
    expected = (
        'jax.numpy.fft.fft only supports norm=None, got something weird\n'
        'Cause: jnp.fft.fft(x,norm="something weird") tensor arg x w/shape (8,)'
    )
    assert msg == expected
def test_mmul():
    """``@`` between a (2, 2) and a (3,) jax array must report both operand shapes."""
    W = jnp.array([[1, 2], [3, 4]])
    b = jnp.array([9, 10, 11])
    expected = (
        "dot_general requires contracting dimensions to have the same shape, got [2] and [3].\n"
        "Cause: @ on tensor operand W w/shape (2, 2) and operand b w/shape (3,)"
    )
    captured = ""
    try:
        with tsensor.clarify():
            # NOTE: keep this statement verbatim -- tsensor re-parses its source
            # text to build the "Cause:" line.
            y = W @ b
    except TypeError as err:
        captured = err.args[0]
    assert captured == expected
def test_scalar_arg():
    """A non-tensor argument to jnp.dot must not hide the tensor argument's shape.

    The expected message is jax's dtype error followed by a tsensor "Cause:"
    line listing only the tensor argument ``x`` and its (5000, 5000) shape.
    """
    size = 5000
    x = np.random.normal(size=(size, size)).astype(np.float32)
    expected = (
        'data type "foo" not understood\n'
        'Cause: jnp.dot(x,"foo") tensor arg x w/shape (5000, 5000)'
    )
    captured = ""
    try:
        with tsensor.clarify():
            # NOTE: keep this statement verbatim -- tsensor re-parses its source
            # text to build the "Cause:" line.
            z = jnp.dot(x, "foo")
    except TypeError as err:
        captured = err.args[0]
    assert captured == expected
def test_dot():
    """jnp.dot on incompatible (5000, 5000) and (5, 1) arrays must report both shapes.

    ``.block_until_ready()`` is part of the statement under test, so the
    expected error must surface inside the clarify() block.
    """
    size = 5000
    # Draw order (x first, then y) is preserved from the original.
    x = np.random.normal(size=(size, size)).astype(np.float32)
    y = np.random.normal(size=(5, 1)).astype(np.float32)
    expected = (
        "Incompatible shapes for dot: got (5000, 5000) and (5, 1).\n"
        "Cause: jnp.dot(x,y) tensor arg x w/shape (5000, 5000), arg y w/shape (5, 1)"
    )
    captured = ""
    try:
        with tsensor.clarify():
            # NOTE: keep this statement verbatim -- tsensor re-parses its source
            # text to build the "Cause:" line.
            z = jnp.dot(x, y).block_until_ready()
    except TypeError as err:
        captured = err.args[0]
    assert captured == expected
def B():
    """Call ``A()`` while a ``tsensor.clarify()`` context is active."""
    clarify_ctx = tsensor.clarify()
    with clarify_ctx:
        A()
def A():
    """Call ``f()`` while a ``tsensor.clarify()`` context is active."""
    clarify_ctx = tsensor.clarify()
    with clarify_ctx:
        f()
W @ np.dot(b, b) + np.eye(2, 2) @ x + z # W[33, 33] = 3 b = np.abs(W @ b + x) def g(): W = torch.tensor([[1, 2], [3, 4]]) b = torch.tensor([9, 10]).reshape(2, 1) x = torch.tensor([4, 5]).reshape(2, 1) z = torch.tensor([1, 2, 3]) # z + z + W @ z # W @ z torch.dot(b, 3) W @ torch.dot(b, b) + torch.eye(2, 2) @ x + z # W[33, 33] = 3 b = torch.abs(W @ b + x) # tr = Tracer() # sys.settrace(tr.listener) # frame = sys._getframe() # frame.f_trace = tr.listener def foo(): g() with tsensor.clarify(): g()