示例#1
0
def test_take_along_axis_2d(dummy: Tensor) -> Tensor:
    """Gather one element per row of an (8, 4) tensor along the last axis.

    Row ``i`` selects column ``i % 4``. The gathered tensor is returned.
    NOTE(review): the returned value is presumably compared across backends
    by a decorator/fixture not visible here; ``dummy`` is presumably a
    backend-specific tensor supplied by a pytest fixture. Confirm.
    """
    values = ep.arange(dummy, 32).float32().reshape((8, 4))
    num_cols = values.shape[-1]
    # One column index per row, wrapped into the valid column range.
    col_idx = ep.arange(values, len(values)) % num_cols
    # take_along_axis needs indices with the same rank as the input.
    col_idx = col_idx[..., ep.newaxis]
    return ep.take_along_axis(values, col_idx, axis=-1)
示例#2
0
def test_take_along_axis_3d(dummy: Tensor) -> Tensor:
    """Gather along the last axis of a (2, 8, 4) tensor.

    The index tensor has shape (2, 8, 1) and holds ``0..15`` modulo 4. The
    gathered tensor is returned.
    NOTE(review): ``dummy`` is presumably a backend-specific fixture tensor.
    Confirm.
    """
    batch, rows = 2, 8
    values = ep.arange(dummy, 64).float32().reshape((batch, rows, 4))
    flat_idx = ep.arange(values, batch * rows).reshape((batch, rows, 1))
    # Wrap the running counter into the valid range of the last axis.
    last_idx = flat_idx % values.shape[-1]
    return ep.take_along_axis(values, last_idx, axis=-1)
示例#3
0
def test_take_along_axis_2d_first_raises(dummy: Tensor) -> None:
    """Check that ``ep.take_along_axis`` with ``axis=0`` raises NotImplementedError.

    The input is an (8, 4) tensor. The indices have shape (1, 4) and hold
    ``arange(4) % 8``.
    """
    values = ep.arange(dummy, 32).float32().reshape((8, 4))
    num_rows = values.shape[0]
    row_idx = ep.arange(values, values.shape[-1]) % num_rows
    # Add a leading axis so the indices have the same rank as the input.
    row_idx = row_idx[ep.newaxis]
    with pytest.raises(NotImplementedError):
        ep.take_along_axis(values, row_idx, axis=0)