def setitem(x, slices, value):
    """Write ``value`` into ``x`` at ``slices`` and return the result of ``x.assign``.

    Args:
        x: the variable to update; it is modified through ``x.assign``.
        slices: either a ``jt.Var`` whose dtype is ``"bool"`` (used as a
            mask), or an index expression forwarded to ``x.setitem``. A
            ``list`` is converted to a ``tuple`` before forwarding.
        value: the value to store; in the mask branch it is broadcast to ``x``.

    Returns:
        Whatever ``x.assign`` returns for the updated contents.
    """
    if isinstance(slices, jt.Var) and slices.dtype == "bool":
        mask = jt.broadcast(slices, x)
        value = jt.broadcast(value, x)
        # The fallback operand must be x (not the mask) so unselected
        # elements keep their current contents, and the result has to be
        # assigned back so this branch updates x like the slice branch does.
        return x.assign(mask.ternary(value, x))
    if isinstance(slices, list):
        slices = tuple(slices)
    return x.assign(x.setitem(slices, value))
def setitem(x, slices, value):
    """Store ``value`` into ``x`` at ``slices`` and return ``x.assign(...)``.

    Three shapes of ``slices`` are handled:

    * a boolean ``jt.Var``: used as a mask, broadcast against ``x``;
    * a ``Sequence``: every boolean ``jt.Var`` element is replaced by the
      items produced by its ``where()`` call, other elements are kept as
      they are, and the result is passed on as a tuple;
    * anything else: forwarded to ``x.setitem`` untouched.
    """

    def _is_bool_var(candidate):
        # True only for jt.Var objects whose dtype is "bool".
        return isinstance(candidate, jt.Var) and candidate.dtype == "bool"

    # Mask assignment: take value where the mask holds, x elsewhere.
    if _is_bool_var(slices):
        selector = jt.broadcast(slices, x)
        replacement = jt.broadcast(value, x)
        return x.assign(selector.ternary(replacement, x))

    # Expand boolean Vars nested inside a sequence of indices.
    if isinstance(slices, Sequence):
        expanded = []
        for entry in slices:
            if _is_bool_var(entry):
                expanded.extend(entry.where())
            else:
                expanded.append(entry)
        slices = tuple(expanded)

    updated = x.setitem(slices, value)
    return x.assign(updated)
def setitem(x, slices, value):
    """Assign ``value`` into ``x`` at ``slices`` via reindex / reindex_reduce.

    The index arguments come from ``slice_var_index(x, slices)``. ``value``
    and a constant ``1`` are broadcast to the shape of the selected slice,
    both are reduced back with ``reindex_reduce("add", ...)`` onto
    ``x.shape``, and the reduced ``1`` acts as the condition choosing
    between the reduced value and the existing ``x``.

    Returns:
        ``x`` itself, after ``x.assign`` has been called with the result.
    """
    gather_args = slice_var_index(x, slices)
    scatter_args = (x.shape, gather_args[1]) + gather_args[3:]

    # View of the elements addressed by ``slices``.
    selected = x.stop_fuse().reindex(*gather_args).stop_fuse()
    filled = jt.broadcast(value, selected)
    ones = jt.broadcast(1, selected)

    if not isinstance(gather_args[0][0], jt.Var):
        # NOTE(review): the rebuilt tuple is not read again in this function;
        # kept so behaviour matches the original exactly.
        gather_args = (x.shape,) + gather_args[1:]

    # Presumably nonzero exactly where ``slices`` targets an element of x;
    # confirm against reindex_reduce semantics.
    touched = ones.reindex_reduce("add", *scatter_args)
    scattered = filled.reindex_reduce("add", *scatter_args)

    # stop_fuse is applied on both the input and the output side to
    # prevent recompilation.
    result = touched.ternary(scattered, x).stop_fuse()
    x.assign(result)
    return x
def test_node_performance(self): mode = os.environ.get("test_node_performance") if mode==None or mode not in "12": return if mode=="1": bc = lambda x: jt.broadcast(x, [1,1,1,1],[0,1,2]) rd = lambda x: jt.sum(x) else: bc = lambda x: jt.reindex(x, [1,1,1,1],["i0+i1+i2+i3"]) rd = lambda x: jt.reindex_reduce(x, "add", [1], ["i0+i1+i2+i3"]) if jt.compiler.is_debug: return def run(): start_time = time.time() fop_num = 10000 fop_input_num = (2, 3) # (i,j) -> [i,i+j] -> [2, 5] # fop_output_num = (1, 0) # [1,1] inner_op_num = (0, 3) fop_type_num = 63 # how many different fuse op input_queue_num = 15 queue = [1.0]*(input_queue_num+1) x = get_xorshf96() rand = lambda x, l, r: l+((x())&r) ops = ["add", "subtract", "multiply", "divide"] get_op = lambda x: ops[(x())&3] for i in range(fop_num): prev = bc(queue[rand(x,0,input_queue_num)]) y = get_xorshf96(x()&fop_type_num) inum = rand(y, *fop_input_num) q = [prev] for i in range(inum-1): n = bc(queue[rand(x,0,input_queue_num)]) prev = jt.binary(prev, n, get_op(y)) q.append(prev) innum = rand(y,*inner_op_num) for _ in range(innum): j = rand(y,0,len(q)-1) n = q[j] prev = jt.binary(prev, n, get_op(y)) q[j] = prev prev = rd(prev) queue[rand(x,0,input_queue_num)] = prev a = jt.array(0.0) for x in queue: a += x LOG.i("build graph", time.time()-start_time, jt.liveness_info().values()) start_time = time.time() a.sync() LOG.i("execute", time.time()-start_time) # debug mode: build(0.68), execute(0.44) # normal mode: build(0.56), execute(0.25) # cast opt: build(0.50), execute(0.25) # dtype opt: build(0.49), execute(0.25) # pyjt opt: build(0.48), execute(0.25) # ns opt: build(0.46), execute(0.24) # nv opt: build(0.42), execute(0.23) # nv opt: build(0.415),execute(0.225) # jit_key opt: build(0.415),execute(0.15) # jit_key opt: build(0.415),execute(0.11) # sv opt: build(0.42), execute(0.12) # noded opt: build(0.42), execute(0.10) # tcm opt: build(0.40), execute(0.10) # mode2: reindex # jit_key opt: build(0.46),execute(0.12) # noded opt: 
build(0.44),execute(0.11) # for i in range(20): # run() for i in range(20): run() import gc gc.collect() run()