Example #1
0
def test_leading_transpose_fails():
    """`leading_transpose` must raise ValueError for a `perm` that does not fit `a`.

    The permutation used here names only two axes of a rank-4 tensor and
    contains no `...` placeholder for the remaining (leading) axes.
    """
    tensor = tf.zeros([1, 2, 3, 4])
    incompatible_perm = [-1, -2]

    with pytest.raises(ValueError):
        leading_transpose(tensor, incompatible_perm)
Example #2
0
def test_leading_transpose():
    """Check the exact output shapes of `leading_transpose` for several perms.

    `a` has shape (1, 2, 3, 4). For a rank-4 tensor, axis 3 is the same as -1
    and axis 2 the same as -2, so `d`, `e` and `f` describe one permutation.
    `...` stands for the leading axes, which are expected to keep their
    original relative order wherever they are placed.
    """
    dims = [1, 2, 3, 4]
    a = tf.zeros(dims)
    b = leading_transpose(a, [..., -1, -2])
    c = leading_transpose(a, [-1, ..., -2])
    d = leading_transpose(a, [-1, -2, ...])
    e = leading_transpose(a, [3, 2, ...])
    f = leading_transpose(a, [3, -2, ...])

    # Compare full shapes rather than ranks/slices so that a misplaced or
    # reordered leading dimension is caught too (the previous partial checks,
    # e.g. only c.shape[0] and c.shape[-1], would miss that).
    assert a.shape == dims
    assert b.shape == [1, 2, 4, 3]
    assert c.shape == [4, 1, 2, 3]
    assert d.shape == [4, 3, 1, 2]
    # Positive and negative spellings of the same permutation must agree.
    assert d.shape == e.shape == f.shape
Example #3
0
 def compiled_wrapper():
     """Return `leading_transpose(a, [..., -1, -2])`: the last two axes of `a` swapped.

     NOTE(review): `a` is a free variable taken from an enclosing scope that is
     not visible here.
     """
     swap_trailing_pair = [..., -1, -2]
     return leading_transpose(a, swap_trailing_pair)