def cholesky(x):
    r"""
    Compute the Cholesky factor of x with ``numpy.linalg.cholesky``.

    The forward pass copies ``np.linalg.cholesky(x)`` into an output that has
    the same shape and dtype as x. The backward pass writes a symmetrized
    gradient with respect to x.

    :param x: input matrix or stack of matrices; presumably (...,M,M) and
        symmetric positive-definite, as ``numpy.linalg.cholesky`` requires.
    :return: L, the first output of the ``jt.numpy_code`` op.
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.cholesky(src))

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        def solve_trans(lhs, rhs):
            # Solve lhs^T y = rhs.
            return np.linalg.solve(T(lhs), rhs)

        def phi(m):
            # Lower triangle of m with its diagonal halved.
            return np.tril(m) / (1. + np.eye(m.shape[-1]))

        def conjugate_solve(factor, m):
            # Evaluates factor^-T @ m @ factor^-1 using two solves.
            return solve_trans(factor, T(solve_trans(factor, T(m))))

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        s = conjugate_solve(
            f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
        # Symmetrize the result before writing it back.
        s = (s + T(s)) / 2.
        np.copyto(out, s)

    outputs = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return outputs[0]
def inv(x):
    r"""
    Invert x with ``numpy.linalg.inv``.

    :param x (...,M,M): matrix (or stack of matrices) to invert.
    :return: x^-1 (...,M,M).
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.inv(src))

    def backward_code(np, data):
        matmul = partial(np.einsum, '...ij,...jk->...ik')

        def T(m):
            return np.swapaxes(m, -1, -2)

        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        inverse = data["f_outputs"][0]
        # Gradient written back: -(x^-T) @ dout @ (x^-T).
        result = -matmul(matmul(T(inverse), grad_out), T(inverse))
        np.copyto(grad_in, result)

    outputs = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return outputs[0]
def solve(a, b):
    r"""
    Solve the linear system ``a @ x = b`` with ``numpy.linalg.solve``.

    :param a: coefficient matrix; presumably (...,M,M) and non-singular, as
        ``numpy.linalg.solve`` requires.
    :param b: right-hand side. The output is declared with b's shape and dtype.
    :return: the solution x, shaped like b.
    """
    def forward_code(np, data):
        lhs, rhs = data["inputs"]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.solve(lhs, rhs))

    def backward_code1(np, data):
        # Gradient w.r.t. a: -(a^-T @ dout) @ x^T.
        def T(m):
            return np.swapaxes(m, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]

        def updim(m):
            # Promote a vector right-hand side to a one-column matrix.
            return m if m.ndim == inp.ndim else m[..., None]

        t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
        np.copyto(out, t)

    def backward_code2(np, data):
        # Gradient w.r.t. b: a^-T @ dout. Previously this wrote zeros, which
        # silently dropped every gradient flowing through b.
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        np.copyto(out, np.linalg.solve(np.swapaxes(inp, -1, -2), dout))

    outputs = jt.numpy_code(
        [b.shape],
        [b.dtype],
        [a, b],
        forward_code,
        [backward_code1, backward_code2],
    )
    return outputs[0]
def pinv(x):
    r"""
    Compute the pseudo-inverse of x with ``numpy.linalg.pinv``.

    :param x (...,M,N): input matrix or stack of matrices.
    :return: x's pinv (...,N,M).
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.pinv(src))

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        mx = data["f_outputs"][0]
        # The gradient is the transpose of the sum of three terms.
        first = -_dot(_dot(mx, T(dout)), mx)
        second = _dot(_dot(_dot(mx, T(mx)), dout),
                      np.eye(inp.shape[-2]) - _dot(inp, mx))
        third = _dot(_dot(_dot(np.eye(mx.shape[-2]) - _dot(mx, inp), dout),
                          T(mx)), mx)
        np.copyto(out, T(first + second + third))

    # The pseudo-inverse swaps the last two axes: (...,M,N) -> (...,N,M).
    # Declaring the output as x.shape only works for square inputs.
    sw = list(x.shape[:-2]) + [x.shape[-1]] + [x.shape[-2]]
    outputs = jt.numpy_code(
        [sw],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return outputs[0]
def det(x):
    r"""
    Compute the determinant of x with ``numpy.linalg.det``.

    :param x: input matrix or stack of matrices; presumably (...,M,M).
    :return: the determinant, shaped ``x.shape[:-2]``, or ``(1,)`` when x is
        a single 2-D matrix.
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.det(src))

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]
        # Append two unit axes so the per-matrix scalars broadcast over the
        # trailing matrix dimensions.
        n_d = np.reshape(dout, np.shape(dout) + (1, 1))
        n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
        # Gradient written back: dout * det(x) * x^-T.
        np.copyto(out, n_d * n_o * T(np.linalg.inv(inp)))

    # Copy into a list so that appending never depends on the concrete type
    # returned by slicing x.shape (pinv and svd build list shapes too).
    out_shape = list(x.shape[:-2])
    if not out_shape:
        # A single matrix yields a one-element output instead of a 0-d one.
        out_shape.append(1)
    outputs = jt.numpy_code(
        [out_shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return outputs[0]
def execute(self, a):
    """Remember ``a`` on the instance and run ``self.forward_code`` on it.

    The op is created through ``jt.numpy_code`` with a single output that has
    the same shape and dtype as ``a``.
    """
    self.save_vars = a
    out_shape, out_dtype = a.shape, a.dtype
    return jt.numpy_code(out_shape, out_dtype, [a], self.forward_code)
def grad(self, grad_a):
    """Run ``self.backward_code`` over the saved input and ``grad_a``.

    ``self.backward_code`` is passed as the forward function of a new
    ``jt.numpy_code`` op whose inputs are ``[saved_input, grad_a]`` and whose
    output has the saved input's shape and dtype.
    """
    saved = self.save_vars
    out_shape, out_dtype = saved.shape, saved.dtype
    return jt.numpy_code(out_shape, out_dtype, [saved, grad_a], self.backward_code)
def execute(self, a, dim):
    """Store ``a`` and ``dim`` on the instance, then run ``self.forward_code``.

    The result of the ``jt.numpy_code`` op (same shape and dtype as ``a``) is
    kept in ``self.res`` and returned.
    """
    self.save_vars = a
    self.dim = dim
    result = jt.numpy_code(a.shape, a.dtype, [a], self.forward_code)
    self.res = result
    return result
def cumsum(x, dim=None):
    '''
    Build a numpy_code op from cumsum_forward / cumsum_backward over x.

    Parameters:
    -----------
    x: [batch_size, N], jt.var
    dim: accepted for the caller's convenience; NOTE(review): this function
         does not read it -- confirm whether that is intended.

    Returns:
    --------
    the cumulative sum of x (same shape and dtype as x)
    '''
    out_shape = x.shape
    out_dtype = x.dtype
    return jt.numpy_code(out_shape, out_dtype, [x], cumsum_forward, [cumsum_backward])
def check():
    """Run the single-output numpy_code op once and verify value and gradient.

    Uses ``forward_code`` / ``backward_code`` from the enclosing scope and
    asserts that the output equals ``a + a`` and its gradient equals 2.
    """
    a = jt.random((5, 1))
    b = jt.numpy_code(a.shape, a.dtype, [a], forward_code, [backward_code])
    assert numpy.allclose(b.data, (a + a).data)

    da = jt.grad(b, a)
    expected_grad = numpy.ones(a.shape) * 2.0
    assert numpy.allclose(da.data, expected_grad)
def qr(x):
    r"""
    Compute the QR factorization of x with ``numpy.linalg.qr``:
    x = QR
    where Q is an orthogonal matrix and R is an upper-triangular matrix.

    :param x (...,M,M): input matrix or stack of matrices.
    :return: q, r, both declared with the shape (...,M,M) and dtype of x.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        q_out, r_out = data["outputs"]
        q_val, r_val = np.linalg.qr(a)
        np.copyto(q_out, q_val)
        np.copyto(r_out, r_val)

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        q, r = data["f_outputs"]
        out_index = data["out_index"]

        if out_index == 0:
            # Contribution of the gradient flowing into Q.
            q_t = _dot(T(q), dout)
            rhs_solve = T(np.tril(q_t - T(q_t), -1))
            qsolve = T(np.linalg.solve(r, rhs_solve))
            np.copyto(out, _dot(q, qsolve))
        else:
            # Contribution of the gradient flowing into R.
            r_t = _dot(r, T(dout))
            rhs_solve = T(np.tril(r_t - T(r_t), -1))
            r_solve = np.linalg.solve(r, rhs_solve)
            np.copyto(out, _dot(q, T(r_solve) + dout))

    q, r = jt.numpy_code(
        [x.shape, x.shape],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return q, r
def eigh(x):
    r"""
    Calculate the eigenvalues and eigenvectors of x with
    ``numpy.linalg.eigh(..., UPLO='L')``.

    :param x (...,M,M): input matrix or stack of matrices.
    :return: w, v.
        w (...,M) : the eigenvalues.
        v (...,M,M) : normalized eigenvectors.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        w_out, v_out = data["outputs"]
        w_val, v_val = np.linalg.eigh(a, UPLO='L')
        np.copyto(w_out, w_val)
        np.copyto(v_out, v_val)

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        w, v = data["f_outputs"]

        if out_index == 0:
            # Gradient through the eigenvalues: v @ diag(dout) @ v^T.
            np.copyto(out, _dot(v * dout[..., np.newaxis, :], T(v)))
        elif out_index == 1:
            if not np.any(dout):
                # An all-zero incoming gradient contributes nothing. The
                # buffer is written explicitly because nothing else on this
                # path fills it.
                np.copyto(out, 0)
                return
            k = int(inp.shape[-1])
            w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
            off_diag = np.ones((k, k)) - np.eye(k)
            # The added identity keeps the diagonal denominators non-zero;
            # off_diag zeroes those entries in the numerator anyway.
            F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
            np.copyto(out, _dot(_dot(v, F * _dot(T(v), dout)), T(v)))

    sw = x.shape[:-2] + x.shape[-1:]
    sv = x.shape
    w, v = jt.numpy_code(
        [sw, sv],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return w, v
def slogdet(x):
    r"""
    Calculate the sign and the log of the absolute determinant of x with
    ``numpy.linalg.slogdet``.

    :param x (...,M,M): input matrix or stack of matrices.
    :return: sign, logdet.
        sign holds the sign of the determinant (-1, 0 or 1; real inputs only).
        A sign of 0 means the determinant is 0 and logdet is -inf.
        Both outputs are shaped ``x.shape[:-2]``, or ``(1,)`` when x is a
        single 2-D matrix.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        sign_out, logdet_out = data["outputs"]
        sign_val, logdet_val = np.linalg.slogdet(a)
        np.copyto(logdet_out, logdet_val)
        np.copyto(sign_out, sign_val)

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        if out_index == 0:
            # No gradient is propagated through the sign output.
            np.copyto(out, 0)
        elif out_index == 1:
            # Gradient through logdet: dout * x^-T.
            t = np.reshape(dout, np.shape(dout) + (1, 1))
            np.copyto(out, t * T(np.linalg.inv(inp)))

    # Copy into a list so that appending never depends on the concrete type
    # returned by slicing x.shape (pinv and svd build list shapes too).
    det_s = list(x.shape[:-2])
    if not det_s:
        # A single matrix yields one-element outputs instead of 0-d ones.
        det_s.append(1)
    sign, mx = jt.numpy_code(
        [det_s, det_s],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return sign, mx
def cholesky(x):
    r"""
    Do the Cholesky decomposition of x with ``numpy.linalg.cholesky``:
    x = LL^T
    x must be a Hermitian, positive-definite matrix. L is a lower-triangular
    matrix.

    :param x (...,M,M): input matrix or stack of matrices.
    :return: L (...,M,M).
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.cholesky(src))

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        def solve_trans(lhs, rhs):
            # Solve lhs^T y = rhs.
            return np.linalg.solve(T(lhs), rhs)

        def phi(m):
            # Lower triangle of m with its diagonal halved.
            return np.tril(m) / (1. + np.eye(m.shape[-1]))

        def conjugate_solve(factor, m):
            # Evaluates factor^-T @ m @ factor^-1 using two solves.
            return solve_trans(factor, T(solve_trans(factor, T(m))))

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        s = conjugate_solve(
            f_out, phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
        # Symmetrize the result before writing it back.
        s = (s + T(s)) / 2.
        np.copyto(out, s)

    outputs = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return outputs[0]
def check():
    """Run the two-output numpy_code op once and verify values and gradients.

    Uses ``forward_code`` / ``backward_code1`` / ``backward_code2`` from the
    enclosing scope. The outputs are expected to be ``a + b`` and ``a - b``.
    """
    a = jt.random((5, 1))
    b = jt.random((5, 1))
    out_shapes = [a.shape, a.shape]
    out_dtypes = [a.dtype, a.dtype]
    c, d = jt.numpy_code(
        out_shapes,
        out_dtypes,
        [a, b],
        forward_code,
        [backward_code1, backward_code2],
    )
    assert numpy.allclose(c.data, (a + b).data)
    assert numpy.allclose(d.data, (a - b).data)

    dca, dcb = jt.grad(c, [a, b])
    dda, ddb = jt.grad(d, [a, b])
    one = numpy.ones(a.shape)
    mone = one * -1.0
    # d(a + b) is +1 for both inputs; d(a - b) is +1 for a and -1 for b.
    for grad_var, expected in ((dca, one), (dcb, one), (dda, one), (ddb, mone)):
        assert numpy.allclose(grad_var.data, expected)
def test_memory_leak(self):
    """Repeat the two-output add/subtract numpy_code op one million times.

    Every iteration rebuilds the op on fresh random inputs and re-checks the
    forward values and all four gradients.
    """
    def forward_code(np, data):
        lhs, rhs = data["inputs"]
        total, diff = data["outputs"]
        np.add(lhs, rhs, out=total)
        np.subtract(lhs, rhs, out=diff)

    def backward_code1(np, data):
        # Gradient w.r.t. the first input is dout for either output.
        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        np.copyto(grad_in, grad_out)

    def backward_code2(np, data):
        # Gradient w.r.t. the second input: +dout for the sum output,
        # -dout for the difference output.
        grad_out = data["dout"]
        which = data["out_index"]
        grad_in = data["outputs"][0]
        if which == 0:
            np.copyto(grad_in, grad_out)
        else:
            np.negative(grad_out, grad_in)

    for _ in range(1000000):
        a = jt.random((10000, 1))
        b = jt.random((10000, 1))
        c, d = jt.numpy_code(
            [a.shape, a.shape],
            [a.dtype, a.dtype],
            [a, b],
            forward_code,
            [backward_code1, backward_code2],
        )
        assert numpy.allclose(c.data, (a + b).data)
        assert numpy.allclose(d.data, (a - b).data)

        dca, dcb = jt.grad(c, [a, b])
        dda, ddb = jt.grad(d, [a, b])
        one = numpy.ones(a.shape)
        mone = one * -1.0
        for grad_var, expected in ((dca, one), (dcb, one), (dda, one), (ddb, mone)):
            assert numpy.allclose(grad_var.data, expected)
def pinv(x):
    r"""
    Compute the pseudo-inverse of x with ``numpy.linalg.pinv``.

    :param x (...,M,N): input matrix or stack of matrices.
    :return: x's pinv (...,N,M).
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.pinv(src))

    def backward_code(np, data):
        matmul = partial(np.einsum, '...ij,...jk->...ik')

        def T(m):
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        mx = data["f_outputs"][0]
        # The gradient is the transpose of the sum of three terms.
        first = -matmul(matmul(mx, T(dout)), mx)
        second = matmul(matmul(matmul(mx, T(mx)), dout),
                        np.eye(inp.shape[-2]) - matmul(inp, mx))
        third = matmul(matmul(matmul(np.eye(mx.shape[-2]) - matmul(mx, inp), dout),
                              T(mx)), mx)
        np.copyto(out, T(first + second + third))

    # Output shape is x's shape with the last two axes swapped.
    batch_dims = list(x.shape[:-2])
    sw = batch_dims + [x.shape[-1]] + [x.shape[-2]]
    outputs = jt.numpy_code(
        [sw],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return outputs[0]
def solve(a, b):
    r"""
    Solve the linear matrix equation Ax = B with ``numpy.linalg.solve``.
    This amounts to x = A^-1 B, so A must not be singular.

    :param a: (...,M,M) coefficient matrix.
    :param b: (...,M) right-hand side. The output is declared with b's shape
        and dtype.
    :return: solution x of Ax = b, shaped like b.
    """
    def forward_code(np, data):
        lhs, rhs = data["inputs"]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.solve(lhs, rhs))

    def backward_code1(np, data):
        # Gradient w.r.t. a: -(a^-T @ dout) @ x^T.
        def T(m):
            return np.swapaxes(m, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]

        def updim(m):
            # Promote a vector right-hand side to a one-column matrix.
            return m if m.ndim == inp.ndim else m[..., None]

        t = -_dot(updim(np.linalg.solve(T(inp), dout)), T(updim(f_out)))
        np.copyto(out, t)

    def backward_code2(np, data):
        # Gradient w.r.t. b: a^-T @ dout. Previously this wrote zeros, which
        # silently dropped every gradient flowing through b.
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        np.copyto(out, np.linalg.solve(np.swapaxes(inp, -1, -2), dout))

    outputs = jt.numpy_code(
        [b.shape],
        [b.dtype],
        [a, b],
        forward_code,
        [backward_code1, backward_code2],
    )
    return outputs[0]
def test(self):
    """Check a single-output numpy_code op that doubles its input.

    The forward pass computes ``a + a``; the backward pass returns
    ``2 * dout``. Both the value and the gradient are verified.
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.add(src, src, out=dst)

    def backward_code(np, data):
        grad_out = data["dout"]
        grad_in = data["outputs"][0]
        np.copyto(grad_in, grad_out * 2.0)

    a = jt.random((5, 1))
    b = jt.numpy_code(a.shape, a.dtype, [a], forward_code, [backward_code])
    assert np.allclose(b.data, (a + a).data)

    da = jt.grad(b, a)
    expected_grad = np.ones(a.shape) * 2.0
    assert np.allclose(da.data, expected_grad)
def slogdet(x):
    r"""
    Calculate the sign and the log of the absolute determinant of x with
    ``numpy.linalg.slogdet``.

    :param x: input matrix or stack of matrices; presumably (...,M,M).
    :return: sign, logdet. Both are shaped ``x.shape[:-2]``, or ``(1,)`` when
        x is a single 2-D matrix.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        sign_out, logdet_out = data["outputs"]
        sign_val, logdet_val = np.linalg.slogdet(a)
        np.copyto(logdet_out, logdet_val)
        np.copyto(sign_out, sign_val)

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        if out_index == 0:
            # No gradient is propagated through the sign output.
            np.copyto(out, 0)
        elif out_index == 1:
            # Gradient through logdet: dout * x^-T.
            t = np.reshape(dout, np.shape(dout) + (1, 1))
            np.copyto(out, t * T(np.linalg.inv(inp)))

    # Copy into a list so that appending never depends on the concrete type
    # returned by slicing x.shape (pinv and svd build list shapes too).
    det_s = list(x.shape[:-2])
    if not det_s:
        # A single matrix yields one-element outputs instead of 0-d ones.
        det_s.append(1)
    sign, mx = jt.numpy_code(
        [det_s, det_s],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return sign, mx
def svd(x):
    r'''
    Calculate the Singular Value Decomposition of x with
    ``numpy.linalg.svd``. It follows the formula below:
    x = u @ diag(s) @ v
    Only the full_matrices == False version is supported for now, which means:
    x's shape (...,M,N)
    u's shape (...,M,K)
    s's shape (...,K)
    v's shape (...,K,N)
    where K is min(M,N).

    :param x: input matrix or stack of matrices.
    :return: u, s, v.
    '''
    def forward_code(np, data):
        a = data["inputs"][0]
        u_out, s_out, v_out = data["outputs"]
        #TODO:remove copyto
        tu, ts, tv = np.linalg.svd(a, full_matrices=0)
        np.copyto(u_out, tu)
        np.copyto(s_out, ts)
        np.copyto(v_out, tv)

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        u, s, vh = data["f_outputs"]
        # The formulas below are written for v (...,N,K), the transpose of
        # the third forward output.
        v = T(vh)
        m, n = inp.shape[-2:]
        k = min(m, n)
        batch_ones = (1,) * (inp.ndim - 2)

        def eye(size):
            # Identity with leading unit axes, one per batch dimension.
            return np.reshape(np.eye(size), batch_ones + (size, size))

        i = eye(k)
        s_row = s[..., np.newaxis, :]
        s_col = s[..., :, np.newaxis]

        if out_index == 0:
            # The added identity keeps the diagonal of the denominator
            # non-zero; the antisymmetric factor zeroes those entries.
            f = 1 / (s_row**2 - s_col**2 + i)
            gu = dout
            utgu = _dot(T(u), gu)
            t = (f * (utgu - T(utgu))) * s_row
            t = _dot(_dot(u, t), T(v))
            if m > n:
                i_minus_uut = eye(m) - _dot(u, np.conj(T(u)))
                t = t + T(_dot(_dot(v / s_row, T(gu)), i_minus_uut))
            np.copyto(out, t)
        elif out_index == 1:
            gs = dout
            t = i * gs[..., :, np.newaxis]
            t = _dot(_dot(u, t), T(v))
            np.copyto(out, t)
        elif out_index == 2:
            f = 1 / (s_row**2 - s_col**2 + i)
            # dout is the gradient of the (...,K,N) output; since v was
            # transposed above, its gradient must be transposed as well.
            # Without this the shapes mismatch whenever M < N, and the
            # result is wrong for any non-symmetric dout.
            gv = T(dout)
            vtgv = _dot(T(v), gv)
            t = s_col * (f * (vtgv - T(vtgv)))
            t = _dot(_dot(u, t), T(v))
            if m < n:
                i_minus_vvt = eye(n) - _dot(v, np.conj(T(v)))
                # This term is already (...,M,N); it is conjugated, not
                # transposed, before being added.
                t = t + np.conj(
                    _dot(_dot(u / s_row, T(gv)), i_minus_vvt))
            np.copyto(out, t)

    m, n = x.shape[-2:]
    k = min(m, n)
    s1 = list(x.shape)
    s1[-1] = k
    s2 = list(x.shape)
    s2[-2] = k
    s3 = list(x.shape)[:-2]
    s3.append(k)
    u, s, v = jt.numpy_code(
        [s1, s3, s2],
        [x.dtype, x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return u, s, v
def svd(x):
    r'''
    Calculate the reduced Singular Value Decomposition of x with
    ``numpy.linalg.svd(..., full_matrices=0)``.

    With x shaped (...,M,N) and K = min(M,N), the outputs are declared as
    u (...,M,K), s (...,K) and v (...,K,N).

    :param x: input matrix or stack of matrices.
    :return: u, s, v.
    '''
    def forward_code(np, data):
        a = data["inputs"][0]
        u_out, s_out, v_out = data["outputs"]
        #TODO:remove copyto
        tu, ts, tv = np.linalg.svd(a, full_matrices=0)
        np.copyto(u_out, tu)
        np.copyto(s_out, ts)
        np.copyto(v_out, tv)

    def backward_code(np, data):
        def T(m):
            return np.swapaxes(m, -1, -2)

        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        u, s, vh = data["f_outputs"]
        # The formulas below are written for v (...,N,K), the transpose of
        # the third forward output.
        v = T(vh)
        m, n = inp.shape[-2:]
        k = min(m, n)
        batch_ones = (1,) * (inp.ndim - 2)

        def eye(size):
            # Identity with leading unit axes, one per batch dimension.
            return np.reshape(np.eye(size), batch_ones + (size, size))

        i = eye(k)
        s_row = s[..., np.newaxis, :]
        s_col = s[..., :, np.newaxis]

        if out_index == 0:
            # The added identity keeps the diagonal of the denominator
            # non-zero; the antisymmetric factor zeroes those entries.
            f = 1 / (s_row**2 - s_col**2 + i)
            gu = dout
            utgu = _dot(T(u), gu)
            t = (f * (utgu - T(utgu))) * s_row
            t = _dot(_dot(u, t), T(v))
            if m > n:
                i_minus_uut = eye(m) - _dot(u, np.conj(T(u)))
                t = t + T(_dot(_dot(v / s_row, T(gu)), i_minus_uut))
            np.copyto(out, t)
        elif out_index == 1:
            gs = dout
            t = i * gs[..., :, np.newaxis]
            t = _dot(_dot(u, t), T(v))
            np.copyto(out, t)
        elif out_index == 2:
            f = 1 / (s_row**2 - s_col**2 + i)
            # dout is the gradient of the (...,K,N) output; since v was
            # transposed above, its gradient must be transposed as well.
            # Without this the shapes mismatch whenever M < N, and the
            # result is wrong for any non-symmetric dout.
            gv = T(dout)
            vtgv = _dot(T(v), gv)
            t = s_col * (f * (vtgv - T(vtgv)))
            t = _dot(_dot(u, t), T(v))
            if m < n:
                i_minus_vvt = eye(n) - _dot(v, np.conj(T(v)))
                # This term is already (...,M,N); it is conjugated, not
                # transposed, before being added.
                t = t + np.conj(
                    _dot(_dot(u / s_row, T(gv)), i_minus_vvt))
            np.copyto(out, t)

    m, n = x.shape[-2:]
    k = min(m, n)
    s1 = list(x.shape)
    s1[-1] = k
    s2 = list(x.shape)
    s2[-2] = k
    s3 = list(x.shape)[:-2]
    s3.append(k)
    u, s, v = jt.numpy_code(
        [s1, s3, s2],
        [x.dtype, x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return u, s, v