def test_max_and_argmax_sparse():
    """Compare ``NativeOp.max_and_argmax_sparse`` with a dense reference.

    The same sparse input is reduced twice: directly with
    ``NativeOp.max_and_argmax_sparse`` and, after densifying it with
    ``NativeOp.sparse_to_dense``, with ``T.max_and_argmax`` over axis 2.
    Both results are compared with the hand-written expectations below.

    NOTE(review): this file defines ``test_max_and_argmax_sparse`` a second
    time further down; that later definition rebinds the name.
    """
    n_time, n_batch, n_dim = 3, 2, 5

    # Sparse input: seven rows with two columns each.
    # Presumably idx_time / idx_dim address the first / last axis of the
    # dense tensor and each column is one batch item -- TODO confirm
    # against the NativeOp implementation.
    idx_time = np.array(
        [[0, 0], [0, 1], [1, 1], [1, 2], [1, 2], [2, 2], [2, 2]], dtype=f32)
    idx_dim = np.array(
        [[1, 2], [2, 3], [1, 1], [2, 0], [4, 1], [3, 3], [4, 4]], dtype=f32)
    weights = np.array(
        [[1, 2], [2, 1], [1, 2], [3, 4], [5, 6], [7, 8], [9, 13]], dtype=f32)
    # The single 0 is in the last row, second column (where the weight is 13).
    mask = np.array(
        [[1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 0]], dtype=f32)

    # Show the dense form of the input for debugging.
    print("W:\n%r" % NativeOp.sparse_to_dense(
        idx_time, idx_dim, weights, mask, n_time, n_dim).eval())

    # Sparse path; both initial outputs are zero tensors of (n_time, n_batch).
    init_out_max = T.zeros((n_time, n_batch), dtype=f32)
    init_out_arg = T.zeros((n_time, n_batch), dtype=f32)
    sparse_max, sparse_arg = NativeOp.max_and_argmax_sparse(
        idx_time, idx_dim, weights, mask, init_out_max, init_out_arg)

    # Dense reference path.
    dense = NativeOp.sparse_to_dense(
        idx_time, idx_dim, weights, mask, n_time, n_dim)
    assert dense.ndim == 3
    dense_max, dense_arg = T.max_and_argmax(dense, axis=2)

    # Hand-written expected results.
    arg0 = np.array([[2, 2], [4, 1], [4, 3]])
    max0 = np.array([[2, 2], [5, 2], [9, 8]])

    arg1 = sparse_arg.eval()
    arg2 = dense_arg.eval()
    max1 = sparse_max.eval()
    max2 = dense_max.eval()

    print("arg0:\n%r" % arg0)
    print("arg1:\n%r" % arg1)
    print("arg2:\n%r" % arg2)
    print("max0:\n%r" % max0)
    print("max1:\n%r" % max1)
    print("max2:\n%r" % max2)

    assert_almost_equal(arg0, arg1)
    assert_almost_equal(arg0, arg2)
    assert_almost_equal(max0, max1)
    assert_almost_equal(max0, max2)
def test_max_and_argmax_sparse():
    """Check the sparse max/argmax op against ``T.max_and_argmax``.

    Builds one small sparse input, runs ``NativeOp.max_and_argmax_sparse``
    on it, and separately converts it with ``NativeOp.sparse_to_dense`` and
    reduces axis 2 with ``T.max_and_argmax``. Both outcomes are asserted to
    be almost equal to the literal expected arrays ``arg0`` / ``max0``.

    NOTE(review): an earlier definition with this same name precedes this
    one in the file; this definition rebinds it.
    """

    def _as_f32(rows):
        # All four sparse inputs use the module-level ``f32`` dtype.
        return np.array(rows, dtype=f32)

    n_time = 3
    n_batch = 2
    n_dim = 5

    # Seven rows, two columns per input array.
    s0 = _as_f32([
        [0, 0], [0, 1], [1, 1], [1, 2], [1, 2], [2, 2], [2, 2],
    ])
    s1 = _as_f32([
        [1, 2], [2, 3], [1, 1], [2, 0], [4, 1], [3, 3], [4, 4],
    ])
    w = _as_f32([
        [1, 2], [2, 1], [1, 2], [3, 4], [5, 6], [7, 8], [9, 13],
    ])
    # Only the very last element is 0 -- presumably masking out the
    # weight 13 so it cannot become the maximum (TODO confirm).
    m = _as_f32([
        [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 1], [1, 0],
    ])

    # Debug output of the densified input.
    dense_for_print = NativeOp.sparse_to_dense(s0, s1, w, m, n_time, n_dim)
    print("W:\n%r" % dense_for_print.eval())

    # Run the op under test, starting from zero-filled outputs.
    out_shape = (n_time, n_batch)
    init_out_max = T.zeros(out_shape, dtype=f32)
    init_out_arg = T.zeros(out_shape, dtype=f32)
    max1, arg1 = NativeOp.max_and_argmax_sparse(
        s0, s1, w, m, init_out_max, init_out_arg
    )

    # Reference: densify, then reduce over the last of the three axes.
    W = NativeOp.sparse_to_dense(s0, s1, w, m, n_time, n_dim)
    assert W.ndim == 3
    max2, arg2 = T.max_and_argmax(W, axis=2)

    # Literal expectations.
    arg0 = np.array([[2, 2], [4, 1], [4, 3]])
    max0 = np.array([[2, 2], [5, 2], [9, 8]])

    # Evaluate the symbolic results in the order arg1, arg2, max1, max2.
    arg1, arg2, max1, max2 = [
        symbolic.eval() for symbolic in (arg1, arg2, max1, max2)
    ]

    print("arg0:\n%r" % arg0)
    print("arg1:\n%r" % arg1)
    print("arg2:\n%r" % arg2)
    print("max0:\n%r" % max0)
    print("max1:\n%r" % max1)
    print("max2:\n%r" % max2)

    # Expected value first, then the sparse result, then the dense result.
    assert_almost_equal(arg0, arg1)
    assert_almost_equal(arg0, arg2)
    assert_almost_equal(max0, max1)
    assert_almost_equal(max0, max2)