def test_add(self):
    """Track var/op counters for the add chain a -> b -> c over 8 deletion patterns.

    On each iteration the low three bits of ``i`` select which of a / b / c are
    deleted before the remaining ones are synced. The hold-var, lived-var and
    lived-op counts are asserted after every stage.
    """
    jt.clean()
    def check(hv, lv, lo):
        # Assert, in order: number of hold vars, lived vars, lived ops.
        self.assertEqual(jt.number_of_hold_vars(), hv)
        self.assertEqual(jt.number_of_lived_vars(), lv)
        self.assertEqual(jt.number_of_lived_ops(), lo)
    for i in range(8):
        check(0, 0, 0)
        # NOTE(review): the 't1'/'t2' arrays are deliberately not bound to Python
        # names; the expected hold-var count of 3 (a, b, c) presumably relies on
        # that. TODO confirm.
        a = jt.array(1.0).name('a').stop_fuse()
        b = (a + jt.array(1.0).name('t1').stop_fuse()).name('b')
        c = (b + jt.array(1.0).name('t2').stop_fuse()).name('c')
        check(3, 5, 5)
        graph = jt.dump_all_graphs()
        # 1.0 + 1.0 + 1.0 == 3. Reading c.data presumably forces evaluation,
        # since the lived-op count drops from 5 to 2 right after.
        self.assertEqual(c.data, 3)
        check(3, 5, 2)
        graph = jt.dump_all_graphs()
        # ':s0' appears to mark fused nodes (see the 'b has been fused' note
        # below) and ':s1' unfused ones. TODO confirm against the dump format.
        for node in graph.nodes_info:
            if node.startswith("Op"):
                if 'add->' in node:
                    assert ':s0' in node, node
                else:
                    assert ':s1' in node, node
            elif ',b,' in node:
                # b has been fused
                assert ':s0' in node, node
            else:
                # NOTE(review): unlike the other asserts, this one carries no
                # `node` message.
                assert ':s1' in node
        # Bits 0 / 1 / 2 of i select deleting a / b / c respectively.
        if i & 1: del a
        if i & 2: del b
        if i & 4: del c
        # Expected counters right after the deletions, per deletion pattern.
        if i == 0: check(3, 5, 2)
        elif i == 1: check(2, 5, 2)
        elif i == 2: check(2, 5, 2)
        elif i == 3: check(1, 1, 0)
        elif i == 4: check(2, 3, 1)
        elif i == 5: check(1, 3, 1)
        elif i == 6: check(1, 1, 0)
        elif i == 7: check(0, 0, 0)
        # Sync whichever of a / b / c survived the deletions above.
        if not (i & 1): a.sync()
        if not (i & 2): b.sync()
        if not (i & 4): c.sync()
        # Expected counters after syncing; i == 7 has nothing left to sync.
        if i == 0: check(3, 5, 2)
        elif i == 1: check(2, 3, 1)
        elif i == 2: check(2, 5, 2)
        elif i == 3: check(1, 1, 0)
        elif i == 4: check(2, 3, 1)
        elif i == 5: check(1, 1, 0)
        elif i == 6: check(1, 1, 0)
        # Drop the remaining vars; everything must be released afterwards.
        if not (i & 1): del a
        if not (i & 2): del b
        if not (i & 4): del c
        check(0, 0, 0)
def test_reshape(self):
    """Reshape one random tensor through several equivalent views and verify them.

    Checks that the shapes are as requested (including a ``-1`` inferred
    dimension) and that every view has the same mean as the source. Rebinding
    ``a`` to ``a + 1`` must not disturb ``b``. The ``get_info`` entries of all
    views must equal the source's.
    """
    n0, n1, n2 = 123, 456, 789
    a = jt.random([n0, n1, n2]).name("a")
    b = jt.reshape(a, [n0 * 2, n2 * n1 // 2]).name("b")
    c = jt.reshape(b, [n0 * n1 * n2]).name("c")
    d = jt.reshape(c, [2, n0 // 3, n2, n1 // 2, 3]).name("d")
    # -1 asks reshape to infer the fourth dimension.
    e = jt.reshape(d, [2, n0 // 3, n2, -1, 3]).name("e")

    five_d = [2, n0 // 3, n2, n1 // 2, 3]
    assert b.shape == [n0 * 2, n2 * n1 // 2]
    assert c.shape == [n0 * n1 * n2]
    assert d.shape == five_d
    assert e.shape == five_d

    a_mean, b_mean, c_mean, d_mean, e_mean = (v.mean().data for v in (a, b, c, d, e))

    # Rebind `a` to a new var; `b` must keep reflecting the original data.
    a = (a + 1).name("new_a")
    new_a_mean = a.mean().data
    new_b_mean = b.mean().data
    node_dict = get_info(jt.dump_all_graphs())

    for other in (b_mean, c_mean, d_mean, e_mean):
        assert check_equal(a_mean, other), f"{a_mean} != {other}"
    assert check_equal(b_mean, new_b_mean), f"{b_mean} != {new_b_mean}"
    assert not check_equal(a_mean, new_a_mean), f"{a_mean} == {new_a_mean}"
    for view_name in "bcde":
        assert node_dict['a'] == node_dict[view_name]
def check(bop_num):
    """Generator used around a code region to count broadcast_to ops.

    Before yielding it calls ``jt.clean()``. After resuming it asserts that
    the graph dump contains exactly ``bop_num`` Op nodes mentioning
    ``broadcast_to``. Presumably wrapped with ``contextlib.contextmanager``
    where it is defined; TODO confirm.
    """
    jt.clean()
    yield
    nodes = jt.dump_all_graphs().nodes_info
    found = sum(1 for node in nodes if node.startswith("Op") and "broadcast_to" in node)
    assert found == bop_num, (found, bop_num)
def check_fused(dim):
    """Assert that the current graph has high-rank nodes and all are marked 's0'.

    A node is high-rank when its estimated rank exceeds ``dim``. The assertion
    fails if no such node exists or if any of them lacks 's0'.
    """
    jt.clean()
    graph = jt.dump_all_graphs()

    def _rank(info):
        # Estimate rank as the number of comma-separated fields after the
        # last '[', minus one for the trailing field.
        return len(info.split('[')[-1].split(',')) - 1

    high_rank = [info for info in graph.nodes_info if _rank(info) > dim]
    assert high_rank and all('s0' in info for info in high_rank), graph.nodes_info
def test_fuse_reduce2(self):
    """Broadcast a 1-element random tensor to ``size`` elements and reduce it.

    Verifies that the broadcast node 'a' is marked 's0' after sum/min/max are
    fetched together. Also checks that sum == v * size and min == max == v,
    where v is the broadcast value.
    """
    size = 10
    a = jt.random([1]).broadcast([size]).name('a')
    b = a.sum().name('b')
    c = a.min().name('c')
    d = a.max().name('d')
    jt.fetch_sync([b, c, d])
    graph = jt.dump_all_graphs()
    node_a = [node for node in graph.nodes_info if ",a," in node]
    # Fail with the graph dump rather than an opaque IndexError if 'a' is absent.
    assert node_a, graph.nodes_info
    assert 's0' in node_a[0]
    v = a.data[0]
    # Every element of `a` equals v, so the sum scales with `size` (not a literal 10).
    assert np.allclose(v * size, b.data) and v == c.data and v == d.data, (v, b.data, c.data, d.data)
def test_longest_dis_fuse(self):
    """Backpropagate through ``resnet_fake`` and check that no large tensor is allocated.

    After syncing the gradients, every Var node whose last '(...)' field is
    not '0' (presumably an allocated pointer; TODO confirm) may have at most
    5 comma-separated shape fields.
    """
    x = jt.array(np.random.rand(1, 3, 224, 224).astype(np.float32))
    loss = jt.sum(resnet_fake(x))
    params = jt.find_vars('resnet_fake')
    grads = jt.grad(loss, params)
    jt.sync(grads)
    # assert not alloc big tensor
    dump = jt.dump_all_graphs()
    var_nodes = (info for info in dump.nodes_info if info.startswith("Var"))
    for info in var_nodes:
        dims = info.split("[")[1].split("]")[0].split(",")
        ptr_field = info.split("(")[1].split(")")[0].split(",")[-1]
        if ptr_field == '0':
            continue
        assert len(dims) <= 5, info
def test6(self):
    """Build a -> b -> c with two adds and check var/op counters around execution.

    Counters are checked on a clean state, after building the chain, and after
    ``c.data`` has been read (which must equal 3).
    """
    jt.clean()

    def expect_counts(hold, lived_vars, lived_ops):
        # Assert, in order: number of hold vars, lived vars, lived ops.
        pairs = (
            (jt.number_of_hold_vars, hold),
            (jt.number_of_lived_vars, lived_vars),
            (jt.number_of_lived_ops, lived_ops),
        )
        for counter, want in pairs:
            self.assertEqual(counter(), want)

    expect_counts(0, 0, 0)
    # Keep the 't1'/'t2' arrays inline. Binding them to names would presumably
    # add hold vars and change the expected counts.
    a = jt.array(1.0).name('a').stop_fuse()
    b = (a + jt.array(1.0).name('t1').stop_fuse()).name('b')
    c = (b + jt.array(1.0).name('t2').stop_fuse()).name('c')
    expect_counts(3, 5, 5)
    _graph = jt.dump_all_graphs()  # result is not inspected
    self.assertEqual(c.data, 3)
    expect_counts(3, 5, 2)
def check(hv, lv, lo):
    """Assert (hold vars, lived vars, lived ops) equals ``(hv, lv, lo)``.

    A Python garbage collection pass and ``jt.graph_check()`` run first. On
    failure the message includes the observed counts and the graph dump.
    """
    import gc
    # Release objects kept alive only by reference cycles before counting.
    gc.collect()
    jt.graph_check()
    observed = (
        jt.number_of_hold_vars(),
        jt.number_of_lived_vars(),
        jt.number_of_lived_ops(),
    )
    assert observed == (hv, lv, lo), (*observed, jt.dump_all_graphs().nodes_info)
def is_fused(x):
    """Return whether the first graph-dump node for ``x`` contains 's0'.

    Side effect: renames ``x`` to ``'_x'`` so that its node can be found in
    the dump. Raises IndexError if no node mentions ``_x``.
    """
    x.name('_x')
    matches = [info for info in jt.dump_all_graphs().nodes_info if ",_x," in info]
    first_match = matches[0]
    return 's0' in first_match