def test_take_along_axis_2d(dummy: Tensor) -> Tensor:
    """Gather one element per row of an (8, 4) float32 tensor along the last axis.

    Row ``i`` selects column ``i % 4``. The index vector is given a trailing
    axis so that its rank matches the source tensor, as ``take_along_axis``
    requires.

    NOTE(review): ``dummy`` is presumably only used to select the backend for
    ``ep.arange``; confirm against the test harness.
    """
    n_rows, n_cols = 8, 4
    values = ep.arange(dummy, n_rows * n_cols).float32().reshape((n_rows, n_cols))
    # One column index per row, wrapped into the valid column range.
    column_idx = ep.arange(values, len(values)) % values.shape[-1]
    expanded_idx = column_idx[..., ep.newaxis]
    return ep.take_along_axis(values, expanded_idx, axis=-1)
def test_take_along_axis_3d(dummy: Tensor) -> Tensor:
    """Gather along the last axis of a (2, 8, 4) float32 tensor.

    The indices form a (2, 8, 1) tensor built from ``arange(16)``. They are
    reduced modulo the size of the last axis (4) so that every index is
    in range.
    """
    shape = (2, 8, 4)
    n_total = shape[0] * shape[1] * shape[2]
    values = ep.arange(dummy, n_total).float32().reshape(shape)
    # One index per (batch, row) pair; the trailing singleton axis keeps the rank at 3.
    raw_idx = ep.arange(values, shape[0] * shape[1]).reshape((shape[0], shape[1], 1))
    last_axis_idx = raw_idx % values.shape[-1]
    return ep.take_along_axis(values, last_axis_idx, axis=-1)
def test_take_along_axis_2d_first_raises(dummy: Tensor) -> None:
    """Check that ``take_along_axis`` along axis 0 raises NotImplementedError.

    A (1, 4) index tensor holding values ``arange(4) % 8`` is passed with
    ``axis=0``. The test passes only if that call raises NotImplementedError.
    """
    values = ep.arange(dummy, 32).float32().reshape((8, 4))
    # One row index per column, wrapped into the valid row range.
    row_idx = ep.arange(values, values.shape[-1]) % values.shape[0]
    with pytest.raises(NotImplementedError):
        ep.take_along_axis(values, row_idx[ep.newaxis], axis=0)