def test_multidimensional_indexing_along_axis(self, t):
    """Check ``fn.take`` with a multi-dimensional index array along ``axis=1``.

    The 2x2 index array is applied to axis 1 of ``t`` and the result is
    compared against a hand-written expected tensor.
    """
    index_grid = np.array([[0, 0], [1, 0]])
    expected = np.array(
        [
            [[[1, 2], [1, 2]], [[3, 4], [1, 2]]],
            [[[5, 6], [5, 6]], [[0, -1], [5, 6]]],
        ]
    )

    res = fn.take(t, index_grid, axis=1)

    assert fn.allclose(res, expected)
def test_array_indexing_along_axis(self, t):
    """Check ``fn.take`` with a flat index sequence along ``axis=2``.

    The index list includes a negative entry, so the expected tensor also
    pins how negative indices are resolved along that axis.
    """
    positions = [0, 1, -2]
    expected = np.array(
        [
            [[1, 2, 1], [3, 4, 3], [-1, 1, -1]],
            [[5, 6, 5], [0, -1, 0], [2, 1, 2]],
        ]
    )

    gathered = fn.take(t, positions, axis=2)

    assert fn.allclose(gathered, expected)
def test_array_indexing(self, t):
    """Check ``fn.take`` with a flat index sequence and no ``axis`` argument.

    Without ``axis`` the indices (one of them negative) are expected to
    address the flattened tensor.
    """
    flat_positions = [0, 2, 3, 6, -2]
    expected = [1, 3, 4, 5, 2]

    picked = fn.take(t, flat_positions)

    assert fn.allclose(picked, expected)
def test_multidimensional_indexing(self, t):
    """Check ``fn.take`` with a nested index sequence and no ``axis`` argument.

    The result is expected to have the same nested layout as the indices,
    with each entry drawn from the flattened tensor.
    """
    nested_positions = [[0, 1], [3, 2]]
    expected = [[1, 2], [4, 3]]

    picked = fn.take(t, nested_positions)

    assert fn.allclose(picked, expected)
def test_flattened_indexing(self, t):
    """Check that ``fn.take`` with a scalar index and no ``axis`` flattens first.

    A single integer index is expected to select one element of the
    flattened tensor.
    """
    flat_position = 5

    picked = fn.take(t, flat_position)

    assert fn.allclose(picked, 1)