def test_mesh_indexing(mode):
    """Trace a multi-axis-vector indexing op and check its graph dump.

    Selects one of ``IndexingMultiAxisVec`` / ``IndexingSetMultiAxisVec`` /
    ``IndexingIncrMultiAxisVec`` according to ``mode``, wraps its application
    in a ``trace(symbolic=True, capture_as_const=True)`` function, runs it once
    and hands the traced function, its inputs and the result to
    ``check_pygraph_dump``.

    Args:
        mode: one of ``"get"``, ``"set"`` or ``"inc"``.

    Raises:
        ValueError: if ``mode`` is not one of the supported values.
    """
    # NOTE(review): each entry looks like (axis, has_begin, has_end, has_step,
    # has_idx) flags, with the actual index values passed positionally via
    # ``tensors`` below (begin=0, end=5, step=2 on axis 0; idx=[1, 3] on
    # axis 1) — confirm against the op definition.
    items = [[0, True, True, True, False], [1, False, False, False, True]]
    tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])]
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))]

    if mode == "get":
        op = builtin.IndexingMultiAxisVec(items)
        # The getter is applied to the source tensor only; the second data
        # tensor is presumably the value operand used by set/inc.
        data = data[:1]
    elif mode == "set":
        op = builtin.IndexingSetMultiAxisVec(items)
    elif mode == "inc":
        op = builtin.IndexingIncrMultiAxisVec(items)
    else:
        # Fail loudly here instead of hitting an UnboundLocalError on `op`
        # inside the traced function below.
        raise ValueError(f"unsupported mode: {mode!r}")

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])
def incr_advance_indexing(input, value, tuple_val):
    """Add ``value`` into the positions of ``input`` selected by ``tuple_val``.

    The index expression is decomposed by ``unpack_getitem`` into the tensor
    to operate on, the index tensors, and the per-axis item descriptors. An
    ``IndexingIncrMultiAxisVec`` op is then built from those descriptors and
    invoked on ``(source, value, *index_tensors)``.

    Args:
        input: the tensor being indexed.
        value: the operand to increment the selected positions by.
        tuple_val: the indexing expression, as passed to ``__getitem__``.

    Returns:
        Whatever ``invoke_op`` returns for the increment op.
    """
    source, index_tensors, index_items = unpack_getitem(input, tuple_val)
    incr_op = builtin.IndexingIncrMultiAxisVec(index_items)
    # Operand order: source tensor, value, then the index tensors.
    operands = (source, value) + tuple(index_tensors)
    return invoke_op(incr_op, operands)