Esempio n. 1
0
def test_mesh_indexing(mode):
    """Trace and dump a multi-axis ("mesh") indexing op for the given *mode*.

    ``mode`` selects the builtin op under test:

    * ``"get"``: ``builtin.IndexingMultiAxisVec``
    * ``"set"``: ``builtin.IndexingSetMultiAxisVec``
    * ``"inc"``: ``builtin.IndexingIncrMultiAxisVec``

    The op is applied inside a function decorated with
    ``trace(symbolic=True, capture_as_const=True)``. That function is run once
    and the traced graph is checked with ``check_pygraph_dump``.

    Raises:
        ValueError: if *mode* is not one of ``"get"``, ``"set"``, ``"inc"``.
    """
    # Validate up front. Otherwise an unknown mode leaves ``op`` unbound, and
    # the failure only surfaces later inside ``fwd`` as a NameError about
    # the free variable ``op``.
    if mode not in ("get", "set", "inc"):
        raise ValueError(
            "unknown mode {!r}; expected 'get', 'set' or 'inc'".format(mode)
        )

    # One descriptor per indexed axis. Presumably laid out as
    # [axis, has_begin, has_end, has_step, has_idx] (TODO confirm against the
    # op's documentation). With ``tensors`` below, this appears to encode
    # x[0:5:2, [1, 3]].
    items = [[0, True, True, True, False], [1, False, False, False, True]]
    tensors = [Tensor(0), Tensor(5), Tensor(2), Tensor([1, 3])]
    # data[0] is the tensor being indexed. data[1] is an extra operand passed
    # to the set/inc ops; its (3, 2) shape presumably matches the selected
    # region (TODO confirm).
    data = [Tensor(np.random.random((5, 5))), Tensor(np.random.random((3, 2)))]
    if mode == "get":
        op = builtin.IndexingMultiAxisVec(items)
        # Only the indexed tensor is passed for "get".
        data = data[:1]
    elif mode == "set":
        op = builtin.IndexingSetMultiAxisVec(items)
    else:  # mode == "inc" (guaranteed by the check above)
        op = builtin.IndexingIncrMultiAxisVec(items)

    @trace(symbolic=True, capture_as_const=True)
    def fwd(*tensors):
        # Keep only the first output produced by ``apply``.
        return apply(op, *tensors)[0]

    result = fwd(*data, *tensors)
    check_pygraph_dump(fwd, data + tensors, [result])
def advance_indexing(input, tuple_val):
    """Index *input* with *tuple_val* via a ``builtin.IndexingMultiAxisVec`` op.

    ``unpack_getitem`` returns three things:

    * the tensor to index;
    * the index tensors;
    * the per-axis item descriptors.

    The descriptors parameterize the op. ``invoke_op`` then runs the op on
    the tensor followed by the index tensors, and this function returns
    whatever ``invoke_op`` returns.
    """
    source, index_tensors, axis_items = unpack_getitem(input, tuple_val)
    return invoke_op(
        builtin.IndexingMultiAxisVec(axis_items),
        (source, *index_tensors),
    )