Example #1
0
def check_cub_argsort(shape, dim, descending = False):
    """Check ``jt.argsort`` against NumPy and that the cub_argsort op is used.

    A random input of the given ``shape`` is sorted along ``dim``. Two results
    are compared with ``np.sort``: the sorted values returned by ``jt.argsort``
    and the values gathered from the input through the returned indices.
    The captured log must contain exactly one "Jit op key ... cub_argsort"
    line.

    Args:
        shape: shape passed to ``jt.random``.
        dim: axis to sort along (compared against non-negative axis numbers).
        descending: sort in descending order when true.
    """
    with jt.log_capture_scope(
        log_silent=1,
        log_v=0, log_vprefix="op.cc=100"
    ) as raw_log:
        x = jt.random(shape)
        y, y_key = jt.argsort(x, dim=dim, descending=descending)
        # One index per axis: the argsort result on the sorted axis and the
        # identity index on every other axis.
        gather_idx = [
            y if axis == dim else jt.index(shape, dim=axis)
            for axis in range(len(shape))
        ]
        gathered_var = jt.reindex(x, gather_idx)
        # Both results are read inside the capture scope so the op key
        # messages end up in raw_log.
        gathered = gathered_var.data
        sorted_keys = y_key.data

    matches = find_log_with_re(raw_log, "(Jit op key (not )?found: " + "cub_argsort" + ".*)")
    assert len(matches)==1

    reference = x.data
    if descending:
        # Descending order is an ascending sort of the negated values.
        expected = -np.sort(-reference, axis=dim)
    else:
        expected = np.sort(reference, axis=dim)
    assert np.allclose(sorted_keys, expected)
    assert np.allclose(gathered, expected)
Example #2
0
def check_argsort(shape, dim, descending = False):
    """Check ``jt.argsort`` against ``np.sort`` on a random input.

    Two results are compared with the NumPy reference: the sorted values
    returned by ``jt.argsort`` and the values gathered from the input through
    the returned indices.

    Args:
        shape: shape passed to ``jt.random``.
        dim: axis to sort along (compared against non-negative axis numbers).
        descending: sort in descending order when true.
    """
    x = jt.random(shape)
    y, y_key = jt.argsort(x, dim=dim, descending=descending)
    # One index per axis: the argsort result on the sorted axis and the
    # identity index on every other axis.
    gather_idx = [
        y if axis == dim else jt.index(shape, dim=axis)
        for axis in range(len(shape))
    ]
    gathered_var = jt.reindex(x, gather_idx)
    gathered = gathered_var.data
    sorted_keys = y_key.data

    reference = x.data
    if descending:
        # Descending order is an ascending sort of the negated values.
        expected = -np.sort(-reference, axis=dim)
    else:
        expected = np.sort(reference, axis=dim)
    assert np.allclose(sorted_keys, expected)
    assert np.allclose(gathered, expected)
    def test_node_performance(self):
        mode = os.environ.get("test_node_performance")
        if mode==None or mode not in "12":
            return
        if mode=="1":
            bc = lambda x: jt.broadcast(x, [1,1,1,1],[0,1,2])
            rd = lambda x: jt.sum(x)
        else:
            bc = lambda x: jt.reindex(x, [1,1,1,1],["i0+i1+i2+i3"])
            rd = lambda x: jt.reindex_reduce(x, "add", [1], ["i0+i1+i2+i3"])
        if jt.compiler.is_debug: return
        def run():
            start_time = time.time()
            fop_num = 10000
            fop_input_num = (2, 3) # (i,j) -> [i,i+j] -> [2, 5]
            # fop_output_num = (1, 0) # [1,1]
            inner_op_num = (0, 3)
            fop_type_num = 63 # how many different fuse op
            input_queue_num = 15
            queue = [1.0]*(input_queue_num+1)
            x = get_xorshf96()
            rand = lambda x, l, r: l+((x())&r)
            ops = ["add", "subtract", "multiply", "divide"]
            get_op = lambda x: ops[(x())&3]
            for i in range(fop_num):
                prev = bc(queue[rand(x,0,input_queue_num)])
                y = get_xorshf96(x()&fop_type_num)
                inum = rand(y, *fop_input_num)
                q = [prev]
                for i in range(inum-1):
                    n = bc(queue[rand(x,0,input_queue_num)])
                    prev = jt.binary(prev, n, get_op(y))
                    q.append(prev)
                innum = rand(y,*inner_op_num)
                for _ in range(innum):
                    j = rand(y,0,len(q)-1)
                    n = q[j]
                    prev = jt.binary(prev, n, get_op(y))
                    q[j] = prev
                prev = rd(prev)
                queue[rand(x,0,input_queue_num)] = prev
            a = jt.array(0.0)
            for x in queue:
                a += x
            LOG.i("build graph", time.time()-start_time, jt.liveness_info().values())
            start_time = time.time()
            a.sync()
            LOG.i("execute", time.time()-start_time)
        # debug mode:  build(0.68), execute(0.44)
        # normal mode: build(0.56), execute(0.25)
        # cast opt:    build(0.50), execute(0.25)
        # dtype opt:   build(0.49), execute(0.25)
        # pyjt opt:    build(0.48), execute(0.25)
        # ns opt:      build(0.46), execute(0.24)
        # nv opt:      build(0.42), execute(0.23)
        # nv opt:      build(0.415),execute(0.225)
        # jit_key opt: build(0.415),execute(0.15)
        # jit_key opt: build(0.415),execute(0.11)
        # sv opt:      build(0.42), execute(0.12)
        # noded opt:   build(0.42), execute(0.10)

        # tcm opt:     build(0.40), execute(0.10)

        # mode2: reindex
        # jit_key opt: build(0.46),execute(0.12)
        # noded opt:   build(0.44),execute(0.11)
        # for i in range(20):
        #     run()
        for i in range(20):
            run()
        import gc
        gc.collect()
        run()