コード例 #1
0
def test_addition():
    """Expect the TF shape error from ``b + x`` to gain a tsensor "Cause:" line.

    The statement inside ``tsensor.clarify()`` is kept exactly as written: the
    expected "Cause:" text names the operands ``b`` and ``x`` as they appear in
    that source line.
    """
    expected = (
        "Incompatible shapes: [2,1] vs. [3,1] [Op:AddV2]\n"
        "Cause: + on tensor operand b w/shape (2, 1) and operand x w/shape (3, 1)"
    )

    caught = ""
    try:
        with tsensor.clarify():
            q = b + x + 3
    except tf.errors.InvalidArgumentError as err:
        caught = err.message

    assert caught == expected
コード例 #2
0
def test_fft():
    """Test a library function that doesn't have a shape related message in the exception.

    ``jnp.fft.fft`` rejects the unsupported ``norm`` value. tsensor is still
    expected to append a "Cause:" line describing the call and the shape of its
    tensor argument ``x``.
    """
    x = np.exp(2j * np.pi * np.arange(8) / 8)
    msg = ""
    # Catch Exception, not BaseException: KeyboardInterrupt / SystemExit must
    # still abort the run instead of being turned into an assertion failure.
    try:
        with tsensor.clarify():
            y = jnp.fft.fft(x, norm="something weird")
    except Exception as e:
        msg = e.args[0]

    expected = ('jax.numpy.fft.fft only supports norm=None, got something weird\n'
                'Cause: jnp.fft.fft(x,norm="something weird") tensor arg x w/shape (8,)')
    assert msg == expected
コード例 #3
0
def test_mmul():
    """Expect a mismatched ``W @ b`` to raise a TypeError carrying tsensor's cause line.

    The names ``W`` and ``b`` and the ``W @ b`` expression must stay as written,
    since the expected "Cause:" text quotes them.
    """
    W = jnp.array([[1, 2], [3, 4]])
    b = jnp.array([9, 10, 11])

    expected = (
        "dot_general requires contracting dimensions to have the same shape, got [2] and [3].\n"
        "Cause: @ on tensor operand W w/shape (2, 2) and operand b w/shape (3,)"
    )

    caught = ""
    try:
        with tsensor.clarify():
            y = W @ b
    except TypeError as err:
        caught = err.args[0]

    assert caught == expected
コード例 #4
0
def test_scalar_arg():
    """Expect a string argument to ``jnp.dot`` to be reported with a tsensor cause line.

    The expected error is about the data type of ``"foo"``, not about ``x``, so
    ``x`` only needs a recognisable shape. A small matrix keeps the test fast
    and light on memory.
    """
    # A 5000x5000 normal draw would allocate ~200 MB of float64 before the cast.
    size = 10
    x = np.random.normal(size=(size, size)).astype(np.float32)

    msg = ""
    try:
        with tsensor.clarify():
            z = jnp.dot(x, "foo")
    except TypeError as e:
        msg = e.args[0]

    expected = ('data type "foo" not understood\n'
                'Cause: jnp.dot(x,"foo") tensor arg x w/shape (10, 10)')
    assert msg == expected
コード例 #5
0
def test_dot():
    """Expect mismatched ``jnp.dot`` operands to be reported with a tsensor cause line.

    ``size`` only has to differ from ``y``'s leading dimension (5) so that the
    contracting dimensions disagree. Keeping it small avoids allocating a large
    random matrix for a test that only checks the error message.
    """
    size = 10  # must not be 5, otherwise (size, size) @ (5, 1) would be valid
    x = np.random.normal(size=(size, size)).astype(np.float32)
    y = np.random.normal(size=(5, 1)).astype(np.float32)

    msg = ""
    try:
        with tsensor.clarify():
            z = jnp.dot(x, y).block_until_ready()
    except TypeError as e:
        msg = e.args[0]

    expected = ("Incompatible shapes for dot: got (10, 10) and (5, 1).\n"
                "Cause: jnp.dot(x,y) tensor arg x w/shape (10, 10), arg y w/shape (5, 1)")
    assert msg == expected
コード例 #6
0
def B():
    """Call A() inside a tsensor.clarify() context.

    In this file A() also wraps its own call in clarify(), so this sets up
    nested clarify() blocks.
    """
    clarifier = tsensor.clarify()
    with clarifier:
        A()
コード例 #7
0
def A():
    """Call f() inside a tsensor.clarify() context; exceptions propagate to the caller."""
    clarifier = tsensor.clarify()
    with clarifier:
        f()
コード例 #8
0
    W @ np.dot(b, b) + np.eye(2, 2) @ x + z
    # W[33, 33] = 3
    b = np.abs(W @ b + x)


def g():
    """Build small torch tensors and apply torch ops to them (scratch/demo code).

    NOTE(review): ``torch.dot(b, 3)`` passes a Python int as the second
    argument. torch.dot is documented to take two tensors, so this call is
    expected to raise. Presumably that is deliberate, so that a surrounding
    ``tsensor.clarify()`` has an exception to annotate. The statements after
    it only run if that call does not raise.
    """
    W = torch.tensor([[1, 2], [3, 4]])  # 2x2 tensor
    b = torch.tensor([9, 10]).reshape(2, 1)  # shape (2, 1)
    x = torch.tensor([4, 5]).reshape(2, 1)  # shape (2, 1)
    z = torch.tensor([1, 2, 3])  # 1-D, 3 elements
    # z + z + W @ z
    # W @ z
    torch.dot(b, 3)
    # NOTE(review): b is 2-D here (reshape(2, 1)) while torch.dot documents 1-D
    # inputs, so torch.dot(b, b) would likely fail too if this line were reached.
    W @ torch.dot(b, b) + torch.eye(2, 2) @ x + z
    # W[33, 33] = 3
    b = torch.abs(W @ b + x)


# tr = Tracer()
# sys.settrace(tr.listener)
# frame = sys._getframe()
# frame.f_trace = tr.listener


def foo():
    """Call g() for its effects and return None; anything g() raises propagates."""
    g()
    return None


# Module-level demo: runs g() under tsensor.clarify() as soon as this file is
# executed or imported, so any exception from g() surfaces at import time.
# NOTE(review): consider guarding this with `if __name__ == "__main__":` so that
# importing the module has no side effects.
with tsensor.clarify():
    g()