def check_cub_argsort(shape, dim, descending=False):
    """Check ``jt.argsort`` along ``dim`` and that it goes through cub_argsort.

    The sort is executed inside a log-capture scope so the JIT op-key log
    lines can be inspected afterwards: exactly one line matching the
    ``cub_argsort`` op key ("found" or "not found") must be emitted. The
    numeric result is then validated against ``np.sort``: both the sorted
    keys returned by ``jt.argsort`` and the values gathered from the input
    with the returned indices must match the NumPy reference.

    Args:
        shape: shape of the random input tensor.
        dim: axis to sort along.
        descending: sort from largest to smallest when True.
    """
    capture_opts = {"log_silent": 1, "log_v": 0, "log_vprefix": "op.cc=100"}
    with jt.log_capture_scope(**capture_opts) as raw_log:
        x = jt.random(shape)
        indices, sorted_keys = jt.argsort(x, dim=dim, descending=descending)
        # Index list for jt.reindex: the argsort indices on the sorted axis,
        # the identity index (jt.index) on every other axis.
        gather = [
            indices if axis == dim else jt.index(shape, dim=axis)
            for axis in range(len(shape))
        ]
        # Fetching .data inside the scope forces execution, so the op-key
        # log lines are captured in raw_log.
        gathered = jt.reindex(x, gather).data
        keys = sorted_keys.data

    op_key_pattern = "(Jit op key (not )?found: " + "cub_argsort" + ".*)"
    matches = find_log_with_re(raw_log, op_key_pattern)
    assert len(matches) == 1

    data = x.data
    # Descending reference: sort the negated values ascending, then negate back.
    if descending:
        expected = -np.sort(-data, axis=dim)
    else:
        expected = np.sort(data, axis=dim)
    assert np.allclose(keys, expected)
    assert np.allclose(gathered, expected)
def check_argsort(shape, dim, descending=False):
    """Verify ``jt.argsort`` along ``dim`` against a NumPy reference sort.

    A random tensor of ``shape`` is argsorted. The returned indices are used
    to gather the original values via ``jt.reindex``. Both the gathered values
    and the sorted keys returned by ``jt.argsort`` must match ``np.sort`` of
    the same data.

    Args:
        shape: shape of the random input tensor.
        dim: axis to sort along.
        descending: sort from largest to smallest when True.
    """
    x = jt.random(shape)
    indices, sorted_keys = jt.argsort(x, dim=dim, descending=descending)
    # Index list for jt.reindex: the argsort indices on the sorted axis,
    # the identity index (jt.index) on every other axis.
    gather = [
        indices if axis == dim else jt.index(shape, dim=axis)
        for axis in range(len(shape))
    ]
    gathered = jt.reindex(x, gather).data
    keys = sorted_keys.data

    data = x.data
    # Descending reference: sort the negated values ascending, then negate back.
    if descending:
        expected = -np.sort(-data, axis=dim)
    else:
        expected = np.sort(data, axis=dim)
    assert np.allclose(keys, expected)
    assert np.allclose(gathered, expected)
def test_node_performance(self): mode = os.environ.get("test_node_performance") if mode==None or mode not in "12": return if mode=="1": bc = lambda x: jt.broadcast(x, [1,1,1,1],[0,1,2]) rd = lambda x: jt.sum(x) else: bc = lambda x: jt.reindex(x, [1,1,1,1],["i0+i1+i2+i3"]) rd = lambda x: jt.reindex_reduce(x, "add", [1], ["i0+i1+i2+i3"]) if jt.compiler.is_debug: return def run(): start_time = time.time() fop_num = 10000 fop_input_num = (2, 3) # (i,j) -> [i,i+j] -> [2, 5] # fop_output_num = (1, 0) # [1,1] inner_op_num = (0, 3) fop_type_num = 63 # how many different fuse op input_queue_num = 15 queue = [1.0]*(input_queue_num+1) x = get_xorshf96() rand = lambda x, l, r: l+((x())&r) ops = ["add", "subtract", "multiply", "divide"] get_op = lambda x: ops[(x())&3] for i in range(fop_num): prev = bc(queue[rand(x,0,input_queue_num)]) y = get_xorshf96(x()&fop_type_num) inum = rand(y, *fop_input_num) q = [prev] for i in range(inum-1): n = bc(queue[rand(x,0,input_queue_num)]) prev = jt.binary(prev, n, get_op(y)) q.append(prev) innum = rand(y,*inner_op_num) for _ in range(innum): j = rand(y,0,len(q)-1) n = q[j] prev = jt.binary(prev, n, get_op(y)) q[j] = prev prev = rd(prev) queue[rand(x,0,input_queue_num)] = prev a = jt.array(0.0) for x in queue: a += x LOG.i("build graph", time.time()-start_time, jt.liveness_info().values()) start_time = time.time() a.sync() LOG.i("execute", time.time()-start_time) # debug mode: build(0.68), execute(0.44) # normal mode: build(0.56), execute(0.25) # cast opt: build(0.50), execute(0.25) # dtype opt: build(0.49), execute(0.25) # pyjt opt: build(0.48), execute(0.25) # ns opt: build(0.46), execute(0.24) # nv opt: build(0.42), execute(0.23) # nv opt: build(0.415),execute(0.225) # jit_key opt: build(0.415),execute(0.15) # jit_key opt: build(0.415),execute(0.11) # sv opt: build(0.42), execute(0.12) # noded opt: build(0.42), execute(0.10) # tcm opt: build(0.40), execute(0.10) # mode2: reindex # jit_key opt: build(0.46),execute(0.12) # noded opt: 
build(0.44),execute(0.11) # for i in range(20): # run() for i in range(20): run() import gc gc.collect() run()