def test_state_create(self):
    """Populate a State with three kinds of values and print it.

    Smoke test only: nothing is asserted. The State, its ``_schema_keys``
    attribute and every stored field are printed for manual inspection.
    """
    state = State()
    state.set(k=torch.randn(3, 4))

    # A second State stored as a field of the first one (key "s").
    nested = State()
    nested.set(v=torch.rand(3, 2))
    state.set(s=nested)

    # Numpy array of strings stored under key "l".
    labels = np.asarray(["sqdf", "qdsf", "qdsf"])
    state.set(l=labels)

    print(state)
    print(state._schema_keys)
    print(state.k)
    print(state.s.v)
    print(state.l)
def test_state_getitem(self):
    """Index a State with an integer and with a slice, printing the results.

    Smoke test only: nothing is asserted. A State holding a tensor, a
    nested State and a numpy string array is indexed with ``2`` and with
    ``:2``; the ``k``, ``l`` and ``s.v`` fields of each result are printed.
    """
    x = State()
    x.set(k=torch.randn(5, 4))

    xsub = State()
    xsub.set(v=torch.rand(5, 2))
    x.set(s=xsub)

    x.set(l=np.asarray(["a", "b", "c", "d", "e"]))

    # Index once and reuse the result instead of binding an unused local
    # and re-indexing for every print.
    # NOTE(review): assumes State.__getitem__ is deterministic and has no
    # side effects, so reusing the result prints the same values - confirm.
    row = x[2]
    print(row.k)
    print(row.l)
    print(row.s.v)

    head = x[:2]
    print(head.k)
    print(head.l)
    print(head.s.v)
def test_state_setitem(self):
    """Assign one State into a slice of another and print the target.

    Smoke test only: nothing is asserted. A target State is built from
    random tensors and a string array, a source State is built from
    ``torch.ones`` tensors and "o" strings, the source is assigned to
    ``target[1:3]``, and the target's fields are printed.
    """
    # Target State: fields "k", nested "s.v" and "l".
    target = State()
    target.set(k=torch.randn(5, 4))

    target_nested = State()
    target_nested.set(v=torch.rand(5, 2))
    target.set(s=target_nested)

    target_labels = np.asarray(["sqdf", "qdsf", "qdsf", "a", "b"])
    target.set(l=target_labels)

    # Source State with the same keys as the target.
    source = State()
    source.set(k=torch.ones(2, 4))

    source_nested = State()
    source_nested.set(v=torch.ones(2, 2))
    source.set(s=source_nested)

    source_labels = np.asarray(["o", "o"])
    source.set(l=source_labels)

    target[1:3] = source

    print(target.k)
    print(target.s.v)
    print(target.l)
def test_state_merge(self):
    """Merge two States with ``State.merge`` and print the outcome.

    Smoke test only: nothing is asserted. Two States with the same keys
    ("k", nested "s.v", "l") are built and passed as a list to
    ``State.merge``. The merged State's ``_schema_keys`` and fields are
    printed, along with the ``k`` field of both inputs for comparison.
    """
    first = State()
    first.set(k=torch.randn(3, 4))

    first_nested = State()
    first_nested.set(v=torch.rand(3, 2))
    first.set(s=first_nested)

    first_labels = np.asarray(["sqdf", "qdsf", "qdsf"])
    first.set(l=first_labels)

    second = State()
    second.set(k=torch.randn(2, 4))

    second_nested = State()
    second_nested.set(v=torch.rand(2, 2))
    second.set(s=second_nested)

    second_labels = np.asarray(["b", "a"])
    second.set(l=second_labels)

    merged = State.merge([first, second])

    print(merged._schema_keys)
    print(merged.k)
    print(first.k)
    print(second.k)
    print(merged.s.v)
    print(merged.l)