def test_batch_mesh_indexing(mode):
    """Trace a batched mesh-indexing op and check that its graph dump round-trips.

    ``mode`` selects the operator under test:

    * ``"get"`` -- ``builtin.BatchedMeshIndexing``, fed only the first data
      tensor;
    * ``"set"`` -- ``builtin.BatchedSetMeshIndexing``, fed both data tensors;
    * ``"inc"`` -- ``builtin.BatchedIncrMeshIndexing``, fed both data tensors.

    The index tensors are appended after the data tensors as op inputs. The
    first output of the op is traced symbolically and passed, together with
    the inputs, to ``check_pygraph_dump``.

    Raises:
        ValueError: if ``mode`` is not one of ``"get"``, ``"set"``, ``"inc"``.
            Without this check an unknown mode would leave ``op`` unbound and
            only fail later, obscurely, inside the traced function.
    """
    # Each item has 5 entries. Assumed to be (axis, has_begin, has_end,
    # has_step, has_idx) as the mesh-indexing ops expect -- TODO confirm.
    # Here axes 1 and 2 are indexed by the two index tensors below.
    items = [[1, False, False, False, True], [2, False, False, False, True]]
    index_tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])]
    # NOTE(review): the second tensor's shape (2, 2, 3) appears to match the
    # region selected from the first, so it presumably serves as the value
    # operand for "set"/"inc" -- confirm.
    data = [
        Tensor(np.random.random((2, 3, 4))),
        Tensor(np.random.random((2, 2, 3))),
    ]
    if mode == "get":
        op = builtin.BatchedMeshIndexing(items)
        # The getter is only given the source tensor.
        data = data[:1]
    elif mode == "set":
        op = builtin.BatchedSetMeshIndexing(items)
    elif mode == "inc":
        op = builtin.BatchedIncrMeshIndexing(items)
    else:
        raise ValueError("unsupported mode: {!r}".format(mode))

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*inputs):
        # Only the first output of the op is checked.
        return apply(op, *inputs)[0]

    result = fwd(*data, *index_tensors)
    check_pygraph_dump(fwd, data + index_tensors, [result])
def batched_mesh_indexing(input, tuple_val):
    """Apply ``builtin.BatchedMeshIndexing`` to ``input`` indexed by ``tuple_val``.

    ``unpack_getitem`` splits the indexing request into three parts: the
    tensor to read from, a sequence of index tensors, and the item
    descriptors used to configure the op. Presumably these are the
    normalized source, the advanced-index tensors and the per-axis index
    flags -- confirm against ``unpack_getitem``.

    Returns:
        Whatever ``invoke_op`` returns for the op applied to
        ``(source, *index_tensors)``.
    """
    src, index_tensors, index_items = unpack_getitem(input, tuple_val)
    mesh_op = builtin.BatchedMeshIndexing(index_items)
    # The source tensor comes first, followed by every index tensor in order.
    operands = (src,) + tuple(index_tensors)
    return invoke_op(mesh_op, operands)