Esempio n. 1
0
def setitem(x, slices, value):
    """Write ``value`` into ``x`` at ``slices`` and return ``x.assign(...)``.

    Args:
        x: Target variable; it is updated through ``x.assign``.
        slices: Either a boolean ``jt.Var`` mask, or an index expression
            accepted by ``x.setitem``. A ``list`` is converted to a ``tuple``
            before being forwarded.
        value: Value to store. In the mask branch it is broadcast against
            ``x`` first.

    Returns:
        Whatever ``x.assign`` returns, in both branches.
    """
    if isinstance(slices, jt.Var) and slices.dtype == "bool":
        mask = jt.broadcast(slices, x)
        value = jt.broadcast(value, x)
        # The fallback operand must be ``x`` (not the mask) so that
        # positions the mask does not select keep their current contents,
        # and the result is written back like the non-mask branch does.
        return x.assign(mask.ternary(value, x))
    if isinstance(slices, list):
        slices = tuple(slices)
    return x.assign(x.setitem(slices, value))
Esempio n. 2
0
def setitem(x, slices, value):
    """Store ``value`` into ``x`` at ``slices`` and return ``x.assign(...)``.

    A boolean ``jt.Var`` passed directly as ``slices`` is treated as a mask:
    both the mask and ``value`` are broadcast against ``x`` and combined
    with ``ternary``. When ``slices`` is a ``Sequence``, every boolean
    ``jt.Var`` inside it is replaced by the items of its ``where()`` result;
    other entries are kept as they are. The resulting tuple is forwarded to
    ``x.setitem``.
    """

    def _is_bool_var(candidate):
        # True only for jt.Var objects whose dtype compares equal to "bool".
        return isinstance(candidate, jt.Var) and candidate.dtype == "bool"

    if _is_bool_var(slices):
        selector = jt.broadcast(slices, x)
        filler = jt.broadcast(value, x)
        return x.assign(selector.ternary(filler, x))

    if isinstance(slices, Sequence):
        expanded = []
        for entry in slices:
            if _is_bool_var(entry):
                # One boolean mask expands into several index entries.
                expanded += entry.where()
            else:
                expanded += [entry]
        slices = tuple(expanded)

    return x.assign(x.setitem(slices, value))
Esempio n. 3
0
def setitem(x, slices, value):
    """Assign ``value`` into ``x`` at ``slices`` and return ``x``.

    The index expression is translated by ``slice_var_index`` into reindex
    arguments. ``value`` and a constant ``1`` are broadcast to the sliced
    view, both are pushed back to ``x.shape`` with ``reindex_reduce("add")``,
    and the reduced ones act as the condition of a ``ternary`` between the
    reduced values and the current ``x``.
    """
    index_args = slice_var_index(x, slices)
    # Arguments for the reverse (reduce) direction: the target shape is
    # x.shape; elements 0 and 2 of the forward arguments are left out.
    scatter_args = (x.shape, index_args[1]) + index_args[3:]
    selected = x.stop_fuse().reindex(*index_args).stop_fuse()
    value = jt.broadcast(value, selected)
    ones = jt.broadcast(1, selected)
    # NOTE(review): this rebinding is not read again in this function;
    # it is preserved because evaluating the condition may still matter.
    if not isinstance(index_args[0][0], jt.Var):
        index_args = (x.shape,) + index_args[1:]
    hit_mask = ones.reindex_reduce("add", *scatter_args)
    scattered = value.reindex_reduce("add", *scatter_args)
    # Stop fuse both input and output, prevent recompile
    x.assign(hit_mask.ternary(scattered, x).stop_fuse())
    return x
    def test_node_performance(self):
        mode = os.environ.get("test_node_performance")
        if mode==None or mode not in "12":
            return
        if mode=="1":
            bc = lambda x: jt.broadcast(x, [1,1,1,1],[0,1,2])
            rd = lambda x: jt.sum(x)
        else:
            bc = lambda x: jt.reindex(x, [1,1,1,1],["i0+i1+i2+i3"])
            rd = lambda x: jt.reindex_reduce(x, "add", [1], ["i0+i1+i2+i3"])
        if jt.compiler.is_debug: return
        def run():
            start_time = time.time()
            fop_num = 10000
            fop_input_num = (2, 3) # (i,j) -> [i,i+j] -> [2, 5]
            # fop_output_num = (1, 0) # [1,1]
            inner_op_num = (0, 3)
            fop_type_num = 63 # how many different fuse op
            input_queue_num = 15
            queue = [1.0]*(input_queue_num+1)
            x = get_xorshf96()
            rand = lambda x, l, r: l+((x())&r)
            ops = ["add", "subtract", "multiply", "divide"]
            get_op = lambda x: ops[(x())&3]
            for i in range(fop_num):
                prev = bc(queue[rand(x,0,input_queue_num)])
                y = get_xorshf96(x()&fop_type_num)
                inum = rand(y, *fop_input_num)
                q = [prev]
                for i in range(inum-1):
                    n = bc(queue[rand(x,0,input_queue_num)])
                    prev = jt.binary(prev, n, get_op(y))
                    q.append(prev)
                innum = rand(y,*inner_op_num)
                for _ in range(innum):
                    j = rand(y,0,len(q)-1)
                    n = q[j]
                    prev = jt.binary(prev, n, get_op(y))
                    q[j] = prev
                prev = rd(prev)
                queue[rand(x,0,input_queue_num)] = prev
            a = jt.array(0.0)
            for x in queue:
                a += x
            LOG.i("build graph", time.time()-start_time, jt.liveness_info().values())
            start_time = time.time()
            a.sync()
            LOG.i("execute", time.time()-start_time)
        # debug mode:  build(0.68), execute(0.44)
        # normal mode: build(0.56), execute(0.25)
        # cast opt:    build(0.50), execute(0.25)
        # dtype opt:   build(0.49), execute(0.25)
        # pyjt opt:    build(0.48), execute(0.25)
        # ns opt:      build(0.46), execute(0.24)
        # nv opt:      build(0.42), execute(0.23)
        # nv opt:      build(0.415),execute(0.225)
        # jit_key opt: build(0.415),execute(0.15)
        # jit_key opt: build(0.415),execute(0.11)
        # sv opt:      build(0.42), execute(0.12)
        # noded opt:   build(0.42), execute(0.10)

        # tcm opt:     build(0.40), execute(0.10)

        # mode2: reindex
        # jit_key opt: build(0.46),execute(0.12)
        # noded opt:   build(0.44),execute(0.11)
        # for i in range(20):
        #     run()
        for i in range(20):
            run()
        import gc
        gc.collect()
        run()