def test_replace_max(self):
    """Overwriting each argmax with a tiny value makes it the new argmin.

    Checked along every axis of a 2-D array, and also with ``axis=None``.
    """
    base = np.array([[10, 30, 20], [60, 40, 50]])
    axes = [*range(base.ndim), None]
    for ax in axes:
        # put_along_axis writes in place, so start from a fresh copy each time
        arr = base.copy()
        argmax_idx = _add_keepdims(np.argmax)(arr, axis=ax)
        # -99 is smaller than every entry of ``base``
        put_along_axis(arr, argmax_idx, -99, axis=ax)
        # the positions of the old maxima must now be the positions of the minima
        argmin_idx = _add_keepdims(np.argmin)(arr, axis=ax)
        assert_equal(argmin_idx, argmax_idx)
def test_broadcast(self): """ Test that non-indexing dimensions are broadcast in both directions """ a = np.ones((3, 4, 1)) ai = np.arange(10, dtype=np.intp).reshape((1, 2, 5)) % 4 put_along_axis(a, ai, 20, axis=1) assert_equal(take_along_axis(a, ai, axis=1), 20)