Example #1
0
    def test_add(self):
        """Track jittor hold/lived var and op counts through a small add chain.

        For every subset i in 0..7 of {a, b, c} (bit 0 -> a, bit 1 -> b,
        bit 2 -> c) the chain ``c = (a + t1) + t2`` is rebuilt and executed.
        The selected vars are deleted and the remaining ones are synced.
        Then all vars are deleted. The counts are compared against the
        expected tables at every step.
        """
        jt.clean()

        def check(hv, lv, lo):
            self.assertEqual(jt.number_of_hold_vars(), hv)
            self.assertEqual(jt.number_of_lived_vars(), lv)
            self.assertEqual(jt.number_of_lived_ops(), lo)

        # Expected (hold_vars, lived_vars, lived_ops), indexed by subset i,
        # right after deleting the vars selected by i ...
        after_del = [
            (3, 5, 2), (2, 5, 2), (2, 5, 2), (1, 1, 0),
            (2, 3, 1), (1, 3, 1), (1, 1, 0), (0, 0, 0),
        ]
        # ... and after syncing the vars that were kept. For i == 7 nothing is
        # left to sync, so the counts stay at (0, 0, 0).
        after_sync = [
            (3, 5, 2), (2, 3, 1), (2, 5, 2), (1, 1, 0),
            (2, 3, 1), (1, 1, 0), (1, 1, 0), (0, 0, 0),
        ]

        for i in range(8):
            check(0, 0, 0)
            a = jt.array(1.0).name('a').stop_fuse()
            b = (a + jt.array(1.0).name('t1').stop_fuse()).name('b')
            c = (b + jt.array(1.0).name('t2').stop_fuse()).name('c')
            check(3, 5, 5)
            # Smoke-check graph dumping before execution; result not inspected.
            jt.dump_all_graphs()
            self.assertEqual(c.data, 3)
            check(3, 5, 2)

            graph = jt.dump_all_graphs()
            for node in graph.nodes_info:
                if node.startswith("Op"):
                    # add ops are expected to carry ':s0', every other op ':s1'
                    expected = ':s0' if 'add->' in node else ':s1'
                elif ',b,' in node:
                    # b has been fused
                    expected = ':s0'
                else:
                    expected = ':s1'
                assert expected in node, node

            if i & 1:
                del a
            if i & 2:
                del b
            if i & 4:
                del c
            check(*after_del[i])

            if not (i & 1):
                a.sync()
            if not (i & 2):
                b.sync()
            if not (i & 4):
                c.sync()
            check(*after_sync[i])

            if not (i & 1):
                del a
            if not (i & 2):
                del b
            if not (i & 4):
                del c
            check(0, 0, 0)
 def test_reshape(self):
     """Reshape one large random tensor through several shapes, one of them
     with an inferred ``-1`` dimension.

     Checks the resulting shapes, that every reshape keeps the same mean, and
     that rebinding ``a`` afterwards leaves ``b`` untouched. Also checks that
     ``get_info`` reports the same info for ``a`` and every reshape of it.
     """
     rows, mid, cols = 123, 456, 789
     shape_b = [rows * 2, cols * mid // 2]
     shape_c = [rows * mid * cols]
     shape_d = [2, rows // 3, cols, mid // 2, 3]

     a = jt.random([rows, mid, cols]).name("a")
     b = jt.reshape(a, shape_b).name("b")
     c = jt.reshape(b, shape_c).name("c")
     d = jt.reshape(c, shape_d).name("d")
     # e lets reshape infer the 4th dimension, which must come out as mid // 2
     e = jt.reshape(d, [2, rows // 3, cols, -1, 3]).name("e")
     assert b.shape == shape_b
     assert c.shape == shape_c
     assert d.shape == shape_d
     assert e.shape == shape_d

     a_mean, b_mean, c_mean, d_mean, e_mean = (
         var.mean().data for var in (a, b, c, d, e))
     # Rebinding the name `a` to a new var must not change b, built from the old a.
     a = (a + 1).name("new_a")
     new_a_mean = a.mean().data
     new_b_mean = b.mean().data
     node_dict = get_info(jt.dump_all_graphs())

     for other_mean in (b_mean, c_mean, d_mean, e_mean):
         assert check_equal(a_mean, other_mean), f"{a_mean} != {other_mean}"
     assert check_equal(b_mean, new_b_mean), f"{b_mean} != {new_b_mean}"
     assert not check_equal(a_mean, new_a_mean), f"{a_mean} == {new_a_mean}"
     for name in ("b", "c", "d", "e"):
         assert node_dict['a'] == node_dict[name]
 def check(bop_num):
     """Generator that clears jittor state, yields once, then asserts how many
     ``broadcast_to`` ops appear in the dumped graph.

     The assertion compares the count of "Op" nodes mentioning
     ``broadcast_to`` with ``bop_num``.

     NOTE(review): the single bare ``yield`` reads like a
     ``contextlib.contextmanager`` setup/teardown, but no decorator is visible
     here. Confirm how callers drive this generator.
     """
     jt.clean()
     yield
     nodes = jt.dump_all_graphs().nodes_info
     found = sum(1 for node in nodes
                 if node.startswith("Op") and "broadcast_to" in node)
     assert found == bop_num, (found, bop_num)
Example #4
0
def check_fused(dim):
    """Assert that the graph has nodes of rank greater than ``dim`` and that
    every such node contains ``'s0'``.

    Calls ``jt.clean()`` first, then inspects ``jt.dump_all_graphs()``. On
    failure the assertion message is the full ``nodes_info`` list.
    """
    jt.clean()
    nodes = jt.dump_all_graphs().nodes_info
    # Rank = number of comma-separated fields after the last '[' minus one.
    # Presumably the shape text ends with a trailing comma; confirm the format.
    big_nodes = [node for node in nodes
                 if len(node.split('[')[-1].split(',')) - 1 > dim]
    assert big_nodes and all('s0' in node for node in big_nodes), nodes
 def test_fuse_reduce2(self):
     """Broadcast one random value to ``size`` elements and feed it to
     sum/min/max reductions.

     Checks that the broadcast var's graph node contains 's0' and that the
     reductions return ``size * v``, ``v`` and ``v``.
     """
     size = 10
     a = jt.random([1]).broadcast([size]).name('a')
     b = a.sum().name('b')
     c = a.min().name('c')
     d = a.max().name('d')
     jt.fetch_sync([b, c, d])

     graph = jt.dump_all_graphs()
     node_a = [node for node in graph.nodes_info if ",a," in node]
     # Fail with the graph dump rather than a bare IndexError if 'a' is missing.
     assert node_a, graph.nodes_info
     assert 's0' in node_a[0]

     v = a.data[0]
     # a holds `size` copies of v, so the sum must scale with `size`
     # (previously hard-coded as 10).
     assert (np.allclose(v * size, b.data)
             and v == c.data and v == d.data), (v, b.data, c.data, d.data)
Example #6
0
 def test_longest_dis_fuse(self):
     """Backprop through ``resnet_fake`` and inspect the dumped graph.

     After the gradients are synced, every "Var" node whose pointer field is
     not '0' must have at most 5 comma-separated shape fields.
     """
     x = jt.array(np.random.rand(1, 3, 224, 224).astype(np.float32))
     loss = jt.sum(resnet_fake(x))
     params = jt.find_vars('resnet_fake')
     grads = jt.grad(loss, params)
     jt.sync(grads)
     # assert not alloc big tensor
     # A non-'0' pointer field presumably marks an allocated var; confirm.
     graph = jt.dump_all_graphs()
     for node in graph.nodes_info:
         if node.startswith("Var"):
             dims = node.split("[")[1].split("]")[0].split(",")
             ptr_field = node.split("(")[1].split(")")[0].split(",")[-1]
             if ptr_field != '0':
                 assert len(dims) <= 5, node
    def test6(self):
        """Build a small add chain and check jittor hold/lived counts before
        and after executing it via ``c.data``."""
        jt.clean()

        def check(hv, lv, lo):
            self.assertEqual(jt.number_of_hold_vars(), hv)
            self.assertEqual(jt.number_of_lived_vars(), lv)
            self.assertEqual(jt.number_of_lived_ops(), lo)

        check(0, 0, 0)
        a = jt.array(1.0).name('a').stop_fuse()
        b = (a + jt.array(1.0).name('t1').stop_fuse()).name('b')
        c = (b + jt.array(1.0).name('t2').stop_fuse()).name('c')
        check(3, 5, 5)
        # Smoke-check graph dumping before execution; the result is not
        # inspected, so it is not bound to a local.
        jt.dump_all_graphs()
        self.assertEqual(c.data, 3)
        check(3, 5, 2)
def check(hv, lv, lo):
    """Assert jittor's (hold vars, lived vars, lived ops) counts equal
    ``(hv, lv, lo)``.

    Runs ``gc.collect()`` and ``jt.graph_check()`` before reading the counts.
    On failure the message holds the three counts plus the dumped graph's
    ``nodes_info``.
    """
    import gc
    # Collect unreachable Python objects first, presumably so released var
    # wrappers are not counted.
    gc.collect()
    jt.graph_check()
    counts = (
        jt.number_of_hold_vars(),
        jt.number_of_lived_vars(),
        jt.number_of_lived_ops(),
    )
    assert counts == (hv, lv, lo), (*counts, jt.dump_all_graphs().nodes_info)
Example #9
0
def is_fused(x):
    """Return whether the dumped graph node for ``x`` contains ``'s0'``.

    Side effect: renames ``x`` to ``'_x'`` (and leaves it renamed) so that its
    node can be found in ``jt.dump_all_graphs()``.

    NOTE(review): if another live var is also named '_x', the first matching
    node is used, which may not be ``x``.

    Raises IndexError if no node matches.
    """
    x.name('_x')
    matches = [info for info in jt.dump_all_graphs().nodes_info
               if ",_x," in info]
    return 's0' in matches[0]