def test(self):
    """Check Jittor's lived-var count across a chain of clones and ``stop_grad``.

    Builds ten successive ``clone()`` calls starting from ``a``. It keeps a
    handle ``c`` to the clone made at ``i == 5`` and ``b`` to the last clone.
    It then asserts the global lived-var count after ``b.sync()`` and again
    after ``c.stop_grad()``.

    Uses ``self.assertEqual`` rather than bare ``assert`` so that the checks
    still run under ``python -O`` and a failure reports the actual count.
    """
    b = a = jt.array(1)
    for i in range(10):
        b = b.clone()
        if i == 5:
            # Keep a Python handle to the middle of the clone chain.
            c = b
    b.sync()
    # 11 == ``a`` plus the ten clones, per the expected value below.
    self.assertEqual(jt.number_of_lived_vars(), 11)
    c.stop_grad()
    # Detaching ``c`` is expected to let most of the chain be released
    # (exact survivors not visible here - TODO confirm which 3 remain).
    self.assertEqual(jt.number_of_lived_vars(), 3)
def test_var_holder(self):
    """Check ``jt.matmul`` rejection of bad operands and its basic results.

    - Three calls with non-matrix / mismatched operands are passed to
      ``expect_error`` (presumably asserting that each one raises).
    - No vars may be alive afterwards.
    - A 1x1 product and a 2x2 product must match exact expected values.
    """
    jt.clean()
    expect_error(lambda: jt.matmul(1, 1))
    expect_error(lambda: jt.matmul([1], [1]))
    expect_error(lambda: jt.matmul([[1]], [1]))
    # The rejected calls must not leave any lived vars behind.
    self.assertEqual(jt.number_of_lived_vars(), 0)

    scalar = jt.matmul(jt.float32([[3]]), jt.float32([[4]])).data
    # Check the shape first so a wrong shape fails cleanly before indexing.
    self.assertEqual(scalar.shape, (1, 1))
    self.assertEqual(scalar[0, 0], 12)

    lhs = np.array([[1, 0], [0, 1]]).astype("float32")
    rhs = np.array([[4, 1], [2, 2]]).astype("float32")
    expected = np.matmul(lhs, rhs)
    got = jt.matmul(jt.array(lhs), jt.array(rhs)).data
    # assert_array_equal also checks shape. ``np.all(got == expected)`` would
    # broadcast, and could silently accept a wrongly shaped result.
    np.testing.assert_array_equal(got, expected)
def check(hv, lv, lo):
    """Assert Jittor's hold-var, lived-var and lived-op counters.

    Before reading the counters, this runs Python's garbage collector and
    ``jt.graph_check()``. On mismatch, the assertion message carries the three
    actual counts followed by ``jt.dump_all_graphs().nodes_info``.

    Args:
        hv: expected ``jt.number_of_hold_vars()``.
        lv: expected ``jt.number_of_lived_vars()``.
        lo: expected ``jt.number_of_lived_ops()``.
    """
    import gc

    # Collect unreachable Python objects first so dropped Var handles
    # do not inflate the counts.
    gc.collect()
    jt.graph_check()
    observed = (
        jt.number_of_hold_vars(),
        jt.number_of_lived_vars(),
        jt.number_of_lived_ops(),
    )
    wanted = (hv, lv, lo)
    assert observed == wanted, (*observed, jt.dump_all_graphs().nodes_info)
def check(hv, lv, lo):
    """Assert Jittor's hold-var, lived-var and lived-op counters in order.

    NOTE(review): ``self`` is not a parameter here. It is read from an
    enclosing scope, presumably the surrounding TestCase method - confirm.

    Args:
        hv: expected ``jt.number_of_hold_vars()``.
        lv: expected ``jt.number_of_lived_vars()``.
        lo: expected ``jt.number_of_lived_ops()``.
    """
    # Each counter is queried lazily, right before its own assertion, so a
    # failure stops before the later counters are read.
    expectations = (
        (jt.number_of_hold_vars, hv),
        (jt.number_of_lived_vars, lv),
        (jt.number_of_lived_ops, lo),
    )
    for counter, want in expectations:
        self.assertEqual(counter(), want)