def test_topk_backward(device_id, precision):
    """Verify the gradient of ``C.top_k`` along the last axis.

    The expected gradient w.r.t. the input is zero everywhere except at the
    positions named by the returned indices, where it equals the matching
    entry of the root gradient. Checked for a parameter wired in through a
    placeholder and for a sequence input.
    """

    def check_grad_last_axis(data, root_grad, topk_idx, actual_grad):
        # Flatten all leading axes so each row is one slice along the last axis.
        width = data.shape[-1]
        k = topk_idx.shape[-1]
        expected = np.zeros_like(data).reshape(-1, width)
        flat_idx = np.reshape(topk_idx, (-1, k))
        flat_root = np.reshape(root_grad, (-1, k))
        assert flat_idx.shape[0] == flat_root.shape[0] == expected.shape[0]
        # Scatter every root-gradient entry to the column its index points at.
        for row, (idx_row, root_row) in enumerate(zip(flat_idx, flat_root)):
            for col, grad_value in zip(idx_row, root_row):
                expected[row, int(col)] = grad_value
        expected = expected.reshape(data.shape)
        assert np.allclose(actual_grad, expected)

    dt = PRECISION_TO_TYPE[precision]
    dev = cntk_device(device_id)
    axis = -1

    # Case 1: a parameter substituted for a placeholder.
    h = C.placeholder()
    p = C.parameter((4, 5, 6))
    p.value = p.value + np.random.randn(*p.shape)
    y = C.top_k(h, 3, axis=axis)
    y.replace_placeholder(p)
    dy, top = y.forward({}, y.outputs, {y.outputs[0]})
    indices = top[y.outputs[1]]
    # Distinct root-gradient entries (1, 2, 3, ...) make misplaced values visible.
    offsets = np.arange(np.prod(indices.shape)).reshape(*indices.shape)
    root = np.ones_like(indices) + offsets
    cg = y.backward(dy, {y.outputs[0]: root}, {p})[p]
    check_grad_last_axis(p.value, root, indices, cg)

    # Case 2: a sequence input with two sequences of different lengths.
    q = C.sequence.input_variable((5, 6), needs_gradient=True)
    q0 = [np.random.randn(4 - n, 5, 6).astype(dt) for n in range(2)]
    y = C.top_k(q, 3, axis=axis)
    dy, top = y.forward({q: q0}, y.outputs, {y.outputs[0]}, device=dev)
    indices = top[y.outputs[1]]
    root = []
    for seq_no, seq_idx in enumerate(indices):
        seq_offsets = np.arange(np.prod(seq_idx.shape)).reshape(*seq_idx.shape)
        # The 100 * seq_no shift keeps the values of different sequences apart.
        root.append(np.ones_like(seq_idx) + 100 * seq_no + seq_offsets)
    cg = y.backward(dy, {y.outputs[0]: root}, {q})[q]
    for seq_no, seq_data in enumerate(q0):
        check_grad_last_axis(seq_data, root[seq_no], indices[seq_no], cg[seq_no])
def test_topk(axis, device_id, precision):
    """Compare ``C.top_k`` (k=3) with a NumPy sort-based reference.

    The reference takes the last three entries of ``np.sort`` /
    ``np.argsort`` along ``axis`` in reversed (descending) order. Checked
    for a parameter, a batched input and a sequence input.
    """

    def sliceit(x, axis):
        # Last three entries along `axis`, reversed so the largest comes first.
        if axis == -1:
            return x[..., -1:-4:-1]
        if axis == -2:
            return x[..., -1:-4:-1, :]
        raise ValueError("unknown axis %d" % axis)

    def check_topk_values_and_indices(top, y, x):
        values = top[y.outputs[0]]
        positions = top[y.outputs[1]]
        for got, sample in zip(values, x):
            want = sliceit(np.sort(sample, axis=axis), axis)
            assert np.allclose(got, want)
        for got, sample in zip(positions, x):
            want = sliceit(np.argsort(sample, axis=axis), axis)
            assert np.allclose(got, want)

    dt = PRECISION_TO_TYPE[precision]
    dev = cntk_device(device_id)

    # Parameter input.
    p = C.parameter((10, 20, 30), dtype=dt)
    np.random.seed(90210)
    p.value = p.value + np.random.randn(*p.shape)
    y = C.top_k(p, 3, axis=axis)
    top = y.eval({})  # for now run this on the device where the parameter is
    expected_values = sliceit(np.sort(p.value, axis=axis), axis)
    assert np.allclose(top[y.outputs[0]], expected_values)
    expected_indices = sliceit(np.argsort(p.value, axis=axis), axis)
    assert np.allclose(top[y.outputs[1]], expected_indices)

    # Batched (non-sequence) input.
    q = C.input_variable((5, 6), dtype=dt)
    q0 = np.random.randn(2, 5, 6).astype(dt)
    y = C.top_k(q, 3, axis=axis)
    top = y.eval({q: q0}, device=dev)
    check_topk_values_and_indices(top, y, q0)

    # Sequence input with two sequences of different lengths.
    q = C.sequence.input_variable((5, 6), dtype=dt)
    q0 = [np.random.randn(4 - n, 5, 6).astype(dt) for n in range(2)]
    y = C.top_k(q, 3, axis=axis)
    top = y.eval({q: q0}, device=dev)
    check_topk_values_and_indices(top, y, q0)
def test_topk(axis, device_id, precision):
    """Check ``C.top_k`` with k=3 against sorted NumPy references.

    Values are compared with the three largest entries from ``np.sort`` and
    indices with the matching entries from ``np.argsort``, both taken along
    ``axis`` in descending order.
    """

    def sliceit(x, axis):
        if axis not in (-2, -1):
            raise ValueError("unknown axis %d" % axis)
        # Reversed slice of the last three positions along the requested axis.
        top3 = slice(-1, -4, -1)
        if axis == -1:
            selector = (Ellipsis, top3)
        else:
            selector = (Ellipsis, top3, slice(None))
        return x[selector]

    def reference(fn, data):
        # NumPy reference: apply `fn` along `axis`, keep the top three entries.
        return sliceit(fn(data, axis=axis), axis)

    def check_topk_values_and_indices(top, y, x):
        vals = top[y.outputs[0]]
        idxs = top[y.outputs[1]]
        for actual, item in zip(vals, x):
            assert np.allclose(actual, reference(np.sort, item))
        for actual, item in zip(idxs, x):
            assert np.allclose(actual, reference(np.argsort, item))

    dt = PRECISION_TO_TYPE[precision]
    dev = cntk_device(device_id)

    # --- parameter ---
    p = C.parameter((10, 20, 30), dtype=dt)
    np.random.seed(90210)
    p.value = p.value + np.random.randn(*p.shape)
    y = C.top_k(p, 3, axis=axis)
    # for now run this on the device where the parameter is
    top = y.eval({})
    assert np.allclose(top[y.outputs[0]], reference(np.sort, p.value))
    assert np.allclose(top[y.outputs[1]], reference(np.argsort, p.value))

    # --- plain input variable ---
    q = C.input_variable((5, 6), dtype=dt)
    q0 = np.random.randn(2, 5, 6).astype(dt)
    y = C.top_k(q, 3, axis=axis)
    top = y.eval({q: q0}, device=dev)
    check_topk_values_and_indices(top, y, q0)

    # --- sequence input variable (sequence lengths 4 and 3) ---
    q = C.sequence.input_variable((5, 6), dtype=dt)
    q0 = [np.random.randn(4 - i, 5, 6).astype(dt) for i in range(2)]
    y = C.top_k(q, 3, axis=axis)
    top = y.eval({q: q0}, device=dev)
    check_topk_values_and_indices(top, y, q0)
def test_topk_backward(device_id, precision):
    """Check that ``C.top_k`` backpropagates only into the selected entries.

    For top-k along the last axis, the gradient of the values output w.r.t.
    the input must hold the root-gradient entries at the selected indices
    and zeros everywhere else.
    """

    def check_grad_last_axis(source, root, indices, output):
        depth = source.shape[-1]
        k = indices.shape[-1]
        reference = np.zeros_like(source).reshape(-1, depth)
        idx2d = np.reshape(indices, (-1, k))
        root2d = np.reshape(root, (-1, k))
        assert idx2d.shape[0] == root2d.shape[0] == reference.shape[0]
        # Row-major walk over every (row, slot) pair of the k selected entries.
        for row, slot in np.ndindex(reference.shape[0], k):
            reference[row, int(idx2d[row, slot])] = root2d[row, slot]
        reference = reference.reshape(source.shape)
        assert np.allclose(output, reference)

    def make_root(idx, base=0):
        # Ones plus a running counter, so every gradient entry is unique.
        counter = np.arange(np.prod(idx.shape)).reshape(*idx.shape)
        return np.ones_like(idx) + base + counter

    dt = PRECISION_TO_TYPE[precision]
    dev = cntk_device(device_id)
    axis = -1

    # Parameter fed through a placeholder.
    h = C.placeholder()
    p = C.parameter((4, 5, 6))
    p.value = p.value + np.random.randn(*p.shape)
    y = C.top_k(h, 3, axis=axis)
    y.replace_placeholder(p)
    dy, top = y.forward({}, y.outputs, set([y.outputs[0]]))
    indices = top[y.outputs[1]]
    root = np.ones_like(indices)
    root = root + np.arange(np.prod(root.shape)).reshape(*root.shape)
    grads = y.backward(dy, {y.outputs[0]: root}, set([p]))
    check_grad_last_axis(p.value, root, indices, grads[p])

    # Sequence input holding sequences of length 4 and 3.
    q = C.sequence.input_variable((5, 6), needs_gradient=True)
    q0 = [np.random.randn(4 - i, 5, 6).astype(dt) for i in range(2)]
    y = C.top_k(q, 3, axis=axis)
    dy, top = y.forward({q: q0}, y.outputs, set([y.outputs[0]]), device=dev)
    indices = top[y.outputs[1]]
    root = [make_root(idx, 100 * n) for n, idx in enumerate(indices)]
    grads = y.backward(dy, {y.outputs[0]: root}, set([q]))
    cg = grads[q]
    for i in range(2):
        check_grad_last_axis(q0[i], root[i], indices[i], cg[i])
def inner(a):
    """Mask everything but the ``k`` largest entries of ``a`` and sample once.

    ``k``, ``axis``, ``num_classes``, ``name`` and ``sample`` come from the
    enclosing scope. Entries outside the top-k are replaced by a large
    negative constant (-1e30) before the result is handed to ``sample``.

    Args:
        a: tensor of shape [#, *] [static_axes, num_classes].

    Returns:
        The output of ``sample`` applied to the masked tensor,
        shape [#, *] [static_axes, num_classes].
    """
    # a: [#, *] [static_axes, num_classes]
    # Only the indices are needed; the top-k values themselves are unused.
    _, k_indices = C.top_k(a, k=k, axis=axis).outputs
    # k_indices: [#, *] [static_axes, k]

    b = C.one_hot(k_indices, num_classes)
    # b: [#, *] [static_axes, k, num_classes]

    # Summing the k one-hot rows gives a mask over num_classes.
    valid_probabilities = C.squeeze(C.reduce_sum(b, axis=-2), axes=(-2,))
    # valid_probabilities: [#, *] [static_axes, num_classes]

    # The k largest probabilities are retained, everything else is set to
    # -inf (approximated by -1e30) and will not be sampled.
    minus_inf = C.constant(-1e+30)
    # Select on the mask itself, not on `a * mask`: a retained entry whose
    # value is exactly 0 would otherwise be treated as masked out.
    e = C.element_select(valid_probabilities, a, minus_inf)
    # e: [#, *] [static_axes, num_classes]

    # Sample from the top-k distribution once.
    s = sample(e, axis=axis, name=name)
    # s: [#, *] [static_axes, num_classes]
    return s