def check_isabstract(constructor, *args):
    """Check when the output of `constructor` is reported as abstract.

    The result of `constructor(*args)` must not be abstract. Afterwards, every
    argument is replaced in turn by the input of a `B.jit`-decorated function,
    and the values that `B.isabstract` reports for the constructed result are
    recorded and checked.

    Args:
        constructor: Callable that is invoked as `constructor(*args)`.
        *args: Arguments for `constructor`. Each one is converted with
            `B.dense` and `jnp.array` before being fed to the jitted function.
    """
    # Applying the constructor to the given arguments must not be abstract.
    plain = constructor(*args)
    assert not B.isabstract(plain)

    # Make each argument abstract in turn and inspect the result.
    for position, original in enumerate(args):
        tracked = []

        @B.jit
        def f(x):
            # Use `x` in place of the argument at `position`; keep the others.
            substituted = (
                x if index == position else arg for index, arg in enumerate(args)
            )
            mat = constructor(*substituted)
            tracked.append(B.isabstract(mat))
            return B.sum(mat)

        f(jnp.array(B.dense(original)))

        # The first run should be concrete: it populates the control flow
        # cache.
        assert not tracked[0]
        # The second run should be abstract.
        assert tracked[1]
        # There should be no further runs.
        assert len(tracked) == 2
def f(x):
    """Construct with argument `i` replaced by `x` and record abstractness.

    Reads `constructor`, `args`, `i`, and `tracked` from the enclosing scope.
    The value of `B.isabstract` for the constructed result is appended to
    `tracked`.

    Args:
        x: Value used in place of `args[i]`.

    Returns:
        The result of `B.sum` applied to the constructed object.
    """
    # Keep every argument except the one at position `i`.
    substituted = [x if j == i else arg for j, arg in enumerate(args)]
    mat = constructor(*substituted)
    tracked.append(B.isabstract(mat))
    return B.sum(mat)
def isabstract(a: Kronecker):
    """Check whether either factor of `a` is abstract.

    Args:
        a: Object whose `left` and `right` attributes are inspected.

    Returns:
        The result of `B.isabstract(a.left)` if it is truthy, and otherwise the
        result of `B.isabstract(a.right)`.
    """
    left_abstract = B.isabstract(a.left)
    if left_abstract:
        return left_abstract
    # Only inspect the right factor when the left one is not abstract.
    return B.isabstract(a.right)
def isabstract(a: LowRank):
    """Check whether any of the factors of `a` is abstract.

    The attributes are inspected in the order `left`, `middle`, `right`, and
    inspection stops at the first truthy result.

    Args:
        a: Object whose `left`, `middle`, and `right` attributes are inspected.

    Returns:
        The first truthy result of `B.isabstract` among the three factors, and
        otherwise the result for `a.right`.
    """
    left_abstract = B.isabstract(a.left)
    if left_abstract:
        return left_abstract
    middle_abstract = B.isabstract(a.middle)
    if middle_abstract:
        return middle_abstract
    return B.isabstract(a.right)
def isabstract(a: Woodbury):
    """Check whether either component of `a` is abstract.

    Args:
        a: Object whose `diag` and `lr` attributes are inspected.

    Returns:
        The result of `B.isabstract(a.diag)` if it is truthy, and otherwise the
        result of `B.isabstract(a.lr)`.
    """
    diag_abstract = B.isabstract(a.diag)
    if diag_abstract:
        return diag_abstract
    # Only inspect `lr` when `diag` is not abstract.
    return B.isabstract(a.lr)
def isabstract(a: Constant):
    """Check whether the `const` attribute of `a` is abstract.

    Args:
        a: Object whose `const` attribute is inspected.

    Returns:
        The result of `B.isabstract(a.const)`.
    """
    const = a.const
    return B.isabstract(const)
def isabstract(a: Diagonal):
    """Check whether the `diag` attribute of `a` is abstract.

    Args:
        a: Object whose `diag` attribute is inspected.

    Returns:
        The result of `B.isabstract(a.diag)`.
    """
    diag = a.diag
    return B.isabstract(diag)
def isabstract(a: Union[Dense, LowerTriangular, UpperTriangular]):
    """Check whether the `mat` attribute of `a` is abstract.

    Args:
        a: Object whose `mat` attribute is inspected.

    Returns:
        The result of `B.isabstract(a.mat)`.
    """
    mat = a.mat
    return B.isabstract(mat)
def f(x):
    """Record whether `x` is abstract and return its sum.

    Appends the result of `B.isabstract(x)` to `tracked`, which is read from
    the enclosing scope.

    Args:
        x: Value to inspect and to sum.

    Returns:
        The result of `B.sum(x)`.
    """
    is_abstract = B.isabstract(x)
    tracked.append(is_abstract)
    return B.sum(x)
def test_isabstract_false(check_lazy_shapes):
    """Check that none of the forms of `Tensor()` is reported as abstract.

    `check_lazy_shapes` is requested as an argument but is not used in the
    body.
    """
    # `any` stops at the first form that is reported as abstract.
    assert not any(B.isabstract(form) for form in Tensor().forms())