Beispiel #1
0
def test_batch_mesh_indexing(mode):
    """Trace a batched mesh-indexing op and validate the dumped graph.

    Builds the op selected by ``mode``, runs it through a symbolic trace
    with ``capture_as_const=True``, and hands the traced function, its
    inputs and its result to ``check_pygraph_dump``.

    Args:
        mode: One of ``"get"`` (``BatchedMeshIndexing``), ``"set"``
            (``BatchedSetMeshIndexing``) or ``"inc"``
            (``BatchedIncrMeshIndexing``).

    Raises:
        ValueError: If ``mode`` is not one of the supported values.
    """
    # Per-axis index descriptors passed verbatim to the op constructor.
    # NOTE(review): presumably (axis, has_begin, has_end, has_step, has_index)
    # -- confirm against the builtin op definition.
    items = [[1, False, False, False, True], [2, False, False, False, True]]
    tensors = [Tensor([[0, 2], [0, 2]]), Tensor([[0, 1, 2], [1, 2, 3]])]
    data = [Tensor(np.random.random((2, 3, 4))), Tensor(np.random.random((2, 2, 3)))]

    if mode == "get":
        op = builtin.BatchedMeshIndexing(items)
        # The "get" op is fed only the first data tensor; "set"/"inc" keep
        # both (presumably destination and value -- TODO confirm).
        data = data[:1]
    elif mode == "set":
        op = builtin.BatchedSetMeshIndexing(items)
    elif mode == "inc":
        op = builtin.BatchedIncrMeshIndexing(items)
    else:
        # Without this branch ``op`` would stay unbound and the failure would
        # only surface as a NameError from inside the traced function.
        raise ValueError(
            "unsupported mode {!r}; expected 'get', 'set' or 'inc'".format(mode)
        )

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        # Only the first output of the op is returned.
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])
def batched_mesh_indexing(input, tuple_val):
    """Apply ``BatchedMeshIndexing`` to ``input`` using the index ``tuple_val``.

    ``unpack_getitem`` splits the index into the source tensor, the index
    tensors and the item descriptors; the descriptors configure the op and
    the tensors become its operands.

    Args:
        input: Tensor being indexed.
        tuple_val: Index specification forwarded to ``unpack_getitem``.

    Returns:
        Whatever ``invoke_op`` returns for the constructed op.
    """
    source, index_tensors, descriptors = unpack_getitem(input, tuple_val)
    mesh_op = builtin.BatchedMeshIndexing(descriptors)
    # Operand order: the source tensor first, then every index tensor.
    operands = (source, *index_tensors)
    return invoke_op(mesh_op, operands)