def testing():
    """Check that ``to_data`` shapes alternate across nested ``markov()`` contexts.

    Four ``markov()`` contexts are entered, one inside the other.  At each
    level a tensor indexed by a fresh name ("1", "2", "3", "4") is converted
    with ``to_data`` and its shape is asserted.  The expected shapes alternate
    between ``(2,)`` and ``(2, 1)`` with nesting depth.
    """
    # The contexts are nested rather than sequential: the alternating expected
    # shapes only make sense if each level sits inside the previous one
    # (presumably markov recycles dims every other level -- TODO confirm).
    with markov():
        v1 = to_data(
            Tensor(jnp.ones(2), OrderedDict([("1", Bint[2])]), "real"))
        print(1, v1.shape)  # shapes should alternate
        assert v1.shape == (2,)

        with markov():
            v2 = to_data(
                Tensor(jnp.ones(2), OrderedDict([("2", Bint[2])]), "real"))
            print(2, v2.shape)  # shapes should alternate
            assert v2.shape == (2, 1)

            with markov():
                v3 = to_data(
                    Tensor(jnp.ones(2), OrderedDict([("3", Bint[2])]), "real"))
                print(3, v3.shape)  # shapes should alternate
                assert v3.shape == (2,)

                with markov():
                    v4 = to_data(
                        Tensor(jnp.ones(2), OrderedDict([("4", Bint[2])]),
                               "real"))
                    print(4, v4.shape)  # shapes should alternate
                    assert v4.shape == (2, 1)
def testing():
    """Iterate five markov steps and check per-step vs. shared data shapes.

    On every step two tensors are converted with ``to_data``: one indexed by
    the step-specific name ``str(step)`` and one indexed by the shared name
    ``'a'``.  The per-step shape is expected to alternate with the step's
    parity (``(2,)`` on even steps, ``(2, 1, 1)`` on odd steps).  The shared
    shape is expected to stay ``(2, 1)``.
    """
    for step in markov(range(5)):
        # NOTE(review): bint(2) / reals() look like the older funsor spellings
        # of Bint[2] / Real; consider migrating once confirmed in scope here.
        step_tensor = Tensor(
            jnp.ones(2), OrderedDict([(str(step), bint(2))]), 'real')
        step_data = to_data(step_tensor)
        shared_tensor = Tensor(
            jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real')
        shared_data = to_data(shared_tensor)
        step_funsor = to_funsor(step_data, reals())
        shared_funsor = to_funsor(shared_data, reals())

        print(step, step_data.shape)  # shapes should alternate
        expected_step_shape = (2,) if step % 2 == 0 else (2, 1, 1)
        assert step_data.shape == expected_step_shape
        assert shared_data.shape == (2, 1)
        print(step, step_funsor.inputs)

        print('a', shared_data.shape)  # shapes should stay the same
        print('a', shared_funsor.inputs)
def testing():
    """Staggered check: touch the shared name ``'a'`` only every fourth step.

    Over twelve markov steps, only steps divisible by 4 build a tensor indexed
    by ``'a'``.  Those steps convert it with ``to_data`` and assert the
    resulting shape is ``(2,)``.  All other steps do nothing.
    """
    for step in markov(range(12)):
        if step % 4:
            continue  # only steps 0, 4, 8 do any work
        shared_tensor = Tensor(
            jnp.zeros(2), OrderedDict([('a', bint(2))]), 'real')
        shared_data = to_data(shared_tensor)
        shared_funsor = to_funsor(shared_data, reals())
        assert shared_data.shape == (2,)
        print('a', shared_data.shape)
        print('a', shared_funsor.inputs)
def testing():
    """Iterate five markov steps and check per-step vs. shared data shapes.

    Each step converts a tensor indexed by ``str(step)`` and one indexed by
    the shared name ``"a"``.  The per-step shape should alternate with parity
    (``(2,)`` even, ``(2, 1, 1)`` odd).  The shared shape should stay
    ``(2, 1)``.
    """
    def _as_data(values, name):
        # Wrap ``values`` in a Tensor indexed by ``name`` (size 2), then convert.
        return to_data(Tensor(values, OrderedDict([(name, Bint[2])]), "real"))

    for step in markov(range(5)):
        step_data = _as_data(jnp.ones(2), str(step))
        shared_data = _as_data(jnp.zeros(2), "a")
        step_funsor = to_funsor(step_data, Real)
        shared_funsor = to_funsor(shared_data, Real)

        print(step, step_data.shape)  # shapes should alternate
        expected = (2,) if step % 2 == 0 else (2, 1, 1)
        assert step_data.shape == expected
        assert shared_data.shape == (2, 1)
        print(step, step_funsor.inputs)

        print("a", shared_data.shape)  # shapes should stay the same
        print("a", shared_funsor.inputs)
def testing():
    """Staggered check: touch the shared name ``"a"`` only every fourth step.

    Of twelve markov steps, those divisible by 4 convert a tensor indexed by
    ``"a"`` and assert the resulting data shape is ``(2,)``.  The remaining
    steps are skipped.
    """
    for step in markov(range(12)):
        if step % 4 != 0:
            continue  # only steps 0, 4, 8 do any work
        shared_data = to_data(
            Tensor(jnp.zeros(2), OrderedDict([("a", Bint[2])]), "real"))
        shared_funsor = to_funsor(shared_data, Real)
        assert shared_data.shape == (2,)
        print("a", shared_data.shape)
        print("a", shared_funsor.inputs)