def test_leading_transpose_fails():
    """
    Verify that `leading_transpose` rejects an incompatible `perm`.

    A two-entry `perm` containing no Ellipsis is passed together with a
    rank-4 tensor; the call is expected to raise `ValueError`.
    """
    tensor = tf.zeros([1, 2, 3, 4])
    incompatible_perm = [-1, -2]

    with pytest.raises(ValueError):
        leading_transpose(tensor, incompatible_perm)
def test_leading_transpose():
    """
    Check the output shapes of `leading_transpose` for several `perm` layouts.

    The input has shape [1, 2, 3, 4]. The Ellipsis in `perm` is placed at the
    end, in the middle and at the start, and the explicit axes are given with
    negative, positive and mixed indices.
    """
    dims = [1, 2, 3, 4]
    a = tf.zeros(dims)

    # Ellipsis first, in the middle, and last (negative indices only).
    ellipsis_first = leading_transpose(a, [..., -1, -2])
    ellipsis_middle = leading_transpose(a, [-1, ..., -2])
    ellipsis_last = leading_transpose(a, [-1, -2, ...])
    # Same layout as `ellipsis_last`, written with positive / mixed indices.
    positive_axes = leading_transpose(a, [3, 2, ...])
    mixed_axes = leading_transpose(a, [3, -2, ...])

    # The rank must be unchanged by every permutation.
    rank = len(a.shape)
    assert (
        rank
        == len(ellipsis_first.shape)
        == len(ellipsis_middle.shape)
        == len(ellipsis_last.shape)
    )
    assert rank == len(positive_axes.shape) == len(mixed_axes.shape)

    # The explicitly listed axes must land where `perm` puts them.
    assert ellipsis_first.shape[-2:] == [4, 3]
    assert ellipsis_middle.shape[0] == 4
    assert ellipsis_middle.shape[-1] == 3
    assert ellipsis_last.shape[:2] == [4, 3]

    # Negative, positive and mixed indexing must all agree.
    assert ellipsis_last.shape == positive_axes.shape == mixed_axes.shape
def compiled_wrapper():
    """
    Return `leading_transpose` applied to `a` with ``perm = [..., -1, -2]``.

    NOTE(review): `a` is not defined inside this function; it is looked up in
    an enclosing or module-level scope that is not visible here. The name
    suggests this wrapper is meant to be compiled (e.g. with `tf.function`),
    but no decorator is visible -- confirm against the surrounding code.
    """
    perm = [..., -1, -2]
    return leading_transpose(a, perm)