Ejemplo n.º 1
0
def cholesky(x):
    r"""
    Cholesky-factorize x with ``np.linalg.cholesky`` through ``jt.numpy_code``.

    The single output is declared with the same shape and dtype as x.
    :param x: input handed to ``np.linalg.cholesky``; presumably (...,M,M) and
        symmetric positive-definite -- not validated here.
    :return: the first output of ``jt.numpy_code`` (the factor L).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        L = data["outputs"][0]
        # Compute with NumPy, then copy into the preallocated output buffer.
        np.copyto(L, np.linalg.cholesky(a))

    def backward_code(np, data):
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        def solve_trans(a, b):
            # Solve a^T y = b for y.
            return np.linalg.solve(T(a), b)

        def phi(m):
            # Lower triangle of m with its diagonal halved.
            return np.tril(m) / (1. + np.eye(m.shape[-1]))

        def conjugate_solve(low, m):
            # Evaluates low^-T m low^-1 by means of two transposed solves.
            return solve_trans(low, T(solve_trans(low, T(m))))

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        # The einsum forms f_out^T dout over the trailing matrix axes.
        s = conjugate_solve(f_out,
                            phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
        # Symmetrize before writing the gradient of the input.
        s = (s + T(s)) / 2.
        np.copyto(out, s)

    lL = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return lL[0]
Ejemplo n.º 2
0
def inv(x):
    r"""
    Invert x with ``np.linalg.inv`` wrapped in ``jt.numpy_code``.
    :param x (...,M,M):
    :return: x^-1 (...,M,M).
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.inv(src))

    def backward_code(np, data):
        def transpose(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        def matmul(lhs, rhs):
            # Batched matrix product over the two trailing axes.
            return np.einsum('...ij,...jk->...ik', lhs, rhs)

        upstream = data["dout"]
        grad_in = data["outputs"][0]
        forward_results = data["f_outputs"]
        inverse = forward_results[0]
        # Gradient written out is -(x^-1)^T @ dout @ (x^-1)^T.
        left = matmul(transpose(inverse), upstream)
        np.copyto(grad_in, -matmul(left, transpose(inverse)))

    results = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return results[0]
Ejemplo n.º 3
0
def solve(a, b):
    r"""
    Solve the linear system a x = b with ``np.linalg.solve`` via ``jt.numpy_code``.

    The output is declared with b's shape and dtype.
    :param a: coefficient matrix; presumably (...,M,M) and non-singular.
    :param b: right-hand side.
    :return: the solution x (first output of ``jt.numpy_code``).
    """
    def forward_code(np, data):
        lhs, rhs = data["inputs"]
        res = data["outputs"][0]
        np.copyto(res, np.linalg.solve(lhs, rhs))

    def backward_code1(np, data):
        # Gradient w.r.t. a: -(a^-T dout) x^T.
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]

        def updim(m):
            # Promote a vector operand to a column matrix so the batched
            # product below becomes an outer product. The rank is compared
            # against the NumPy input rather than the outer Jittor variable.
            return m if m.ndim == inp.ndim else m[..., None]

        t = -np.einsum('...ij,...jk->...ik',
                       updim(np.linalg.solve(T(inp), dout)),
                       T(updim(f_out)))
        np.copyto(out, t)

    def backward_code2(np, data):
        # Gradient w.r.t. b: a^-T dout. This used to be written as zeros,
        # which silently dropped every gradient flowing into b.
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        np.copyto(out, np.linalg.solve(np.swapaxes(inp, -1, -2), dout))

    l_ans = jt.numpy_code(
        [b.shape],
        [b.dtype],
        [a, b],
        forward_code,
        [backward_code1, backward_code2],
    )
    return l_ans[0]
Ejemplo n.º 4
0
def pinv(x):
    r"""
    Compute the pseudo-inverse of x with ``np.linalg.pinv`` via ``jt.numpy_code``.

    The output is declared with the two trailing axes of x swapped, so
    non-square inputs get a correctly sized result buffer.
    :param x (...,M,N):
    :return: pinv of x (...,N,M).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        m_a = data["outputs"][0]
        np.copyto(m_a, np.linalg.pinv(a))

    def backward_code(np, data):
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        def _dot(lhs, rhs):
            # Batched matrix product over the two trailing axes.
            return np.einsum('...ij,...jk->...ik', lhs, rhs)

        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        mx = data["f_outputs"][0]
        # Three terms are summed in the pseudo-inverse's layout and the sum is
        # transposed back to the layout of the input.
        t = T(
            -_dot(_dot(mx, T(dout)), mx) +
            _dot(_dot(_dot(mx, T(mx)), dout),
                 np.eye(inp.shape[-2]) - _dot(inp, mx)) +
            _dot(_dot(_dot(np.eye(mx.shape[-2]) -
                           _dot(mx, inp), dout), T(mx)), mx))
        np.copyto(out, t)

    # pinv maps (...,M,N) to (...,N,M); declaring x.shape is only right when
    # the matrices are square.
    out_shape = list(x.shape[:-2]) + [x.shape[-1], x.shape[-2]]
    lmx = jt.numpy_code(
        [out_shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return lmx[0]
Ejemplo n.º 5
0
def det(x):
    r"""
    Compute the determinant of x with ``np.linalg.det`` via ``jt.numpy_code``.
    :param x: input handed to ``np.linalg.det``; presumably (...,M,M).
    :return: the determinant, declared with x's leading (batch) shape, or with
        shape (1,) when x has exactly two dimensions.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        res = data["outputs"][0]
        np.copyto(res, np.linalg.det(a))

    def backward_code(np, data):
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]
        # Append two unit axes so the per-matrix scalars broadcast over (M,M).
        n_d = np.reshape(dout, np.shape(dout) + (1, 1))
        n_o = np.reshape(f_out, np.shape(f_out) + (1, 1))
        # Gradient written out is dout * det(x) * x^-T.
        np.copyto(out, n_d * n_o * T(np.linalg.inv(inp)))

    # Build the output shape as a plain list so append works whatever sequence
    # type x.shape happens to be.
    out_shape = list(x.shape[:-2])
    if len(x.shape) == 2:
        out_shape.append(1)
    l_det = jt.numpy_code(
        [out_shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return l_det[0]
Ejemplo n.º 6
0
 def execute(self, a):
     """Store the input on the instance and run ``self.forward_code``.

     The call goes through ``jt.numpy_code`` with an output declared to have
     the same shape and dtype as ``a``.
     """
     self.save_vars = a
     out_shape, out_dtype = a.shape, a.dtype
     return jt.numpy_code(out_shape, out_dtype, [a], self.forward_code)
Ejemplo n.º 7
0
 def grad(self, grad_a):
     """Run ``self.backward_code`` on the saved input and ``grad_a``.

     The result is declared with the shape and dtype of the value stored in
     ``self.save_vars``.
     """
     saved = self.save_vars
     operands = [saved, grad_a]
     return jt.numpy_code(saved.shape, saved.dtype, operands,
                          self.backward_code)
Ejemplo n.º 8
0
 def execute(self, a, dim):
     """Store ``a`` and ``dim`` on the instance and run ``self.forward_code``.

     The result of ``jt.numpy_code`` (declared with the shape and dtype of
     ``a``) is kept in ``self.res`` and returned.
     """
     self.save_vars = a
     self.dim = dim
     out_shape, out_dtype = a.shape, a.dtype
     self.res = jt.numpy_code(out_shape, out_dtype, [a], self.forward_code)
     return self.res
Ejemplo n.º 9
0
def cumsum(x, dim=None):
    '''
    Cumulative sum of x, evaluated by ``cumsum_forward`` through
    ``jt.numpy_code`` with ``cumsum_backward`` as its gradient.

    Parameters:
    -----------
    x: [batch_size, N], jt.var
    dim: accepted, but not referenced anywhere in this function body

    Returns:
    --------
    the cumulative sum of x, declared with the shape and dtype of x
    '''
    out_shape, out_dtype = x.shape, x.dtype
    return jt.numpy_code(out_shape, out_dtype, [x], cumsum_forward,
                         [cumsum_backward])
Ejemplo n.º 10
0
 def check():
     """Run the op on a random (5, 1) input and verify value and gradient.

     The output must equal ``a + a`` and its gradient w.r.t. ``a`` must be 2.
     """
     a = jt.random((5, 1))
     doubled = jt.numpy_code(a.shape, a.dtype, [a], forward_code,
                             [backward_code])
     assert numpy.allclose(doubled.data, (a + a).data)
     grad_a = jt.grad(doubled, a)
     expected = numpy.ones(a.shape) * 2.0
     assert numpy.allclose(grad_a.data, expected)
Ejemplo n.º 11
0
def qr(x):
    r"""
    Compute the QR factorization of x:
    x = QR, where Q is an orthogonal matrix and R is an upper-triangular matrix.
    Both outputs are declared with the shape and dtype of x.
    :param x (...,M,M):
    :return: q, r, the two factors of x, both of shape (...,M,M).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        q, r = data["outputs"]
        Q, R = np.linalg.qr(a)
        # Copy the NumPy factors into the preallocated output buffers.
        np.copyto(q, Q)
        np.copyto(r, R)

    def backward_code(np, data):
        def T(x):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(x, -1, -2)

        # Batched matrix product over the two trailing axes.
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        # Batched element-wise product; not referenced below.
        _harmard = partial(np.einsum, '...ij,...ij->...ij')
        dout = data["dout"]
        out = data["outputs"][0]
        q, r = data["f_outputs"]
        # Selects which forward output dout belongs to: 0 is q, anything else r.
        out_index = data["out_index"]
        #pl = np.tril(np.ones((inp.shape[-1],inp.shape[-1])))-diags
        if out_index == 0:  # Q_TERM
            q_t = _dot(T(q), dout)
            # Antisymmetric part of q^T dout; only its strictly lower triangle
            # is kept (then transposed) as the right-hand side.
            rhs_solve = q_t - T(q_t)
            rhs_solve = T(np.tril(rhs_solve, -1))
            qsolve = np.linalg.solve(r, rhs_solve)
            qsolve = T(qsolve)
            tq = _dot(q, qsolve)
            np.copyto(out, tq)
        else:  #R_TERM
            r_t = _dot(r, T(dout))
            # Same construction as above, starting from r dout^T.
            rhs_solve = r_t - T(r_t)
            rhs_solve = np.tril(rhs_solve, -1)
            rhs_solve = T(rhs_solve)
            r_solve = np.linalg.solve(r, rhs_solve)
            tr = _dot(q, (T(r_solve) + dout))
            np.copyto(out, tr)

    # One backward function serves both outputs; it dispatches on out_index.
    q, r = jt.numpy_code(
        [x.shape, x.shape],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return q, r
Ejemplo n.º 12
0
def eigh(x):
    r"""
    Compute the eigenvalues and eigenvectors of x with
    ``np.linalg.eigh(..., UPLO='L')`` via ``jt.numpy_code``.
    :param x (...,M,M):
    :return: w, v.
    w (...,M) : the eigenvalues.
    v (...,M,M) : normalized eigenvectors.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        w, v = data["outputs"]
        tw, tv = np.linalg.eigh(a, UPLO='L')
        np.copyto(w, tw)
        np.copyto(v, tv)

    def backward_code(np, data):
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        def _dot(lhs, rhs):
            # Batched matrix product over the two trailing axes.
            return np.einsum('...ij,...jk->...ik', lhs, rhs)

        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        # Selects which forward output dout belongs to: 0 is w, 1 is v.
        out_index = data["out_index"]
        w, v = data["f_outputs"]
        k = int(inp.shape[-1])
        if out_index == 0:
            # Eigenvalue path: v diag(dout) v^T.
            t = _dot(v * dout[..., np.newaxis, :], T(v))
            np.copyto(out, t)
        elif out_index == 1:
            if np.any(dout):
                w_repeated = np.repeat(w[..., np.newaxis], k, axis=-1)
                off_diag = np.ones((k, k)) - np.eye(k)
                # The added identity keeps the diagonal denominators non-zero;
                # off_diag then zeroes those diagonal entries.
                F = off_diag / (T(w_repeated) - w_repeated + np.eye(k))
                t = _dot(_dot(v, F * _dot(T(v), dout)), T(v))
                np.copyto(out, t)
            else:
                # An all-zero upstream gradient contributes nothing. Write the
                # zeros explicitly, since nothing here shows the buffer is
                # cleared before this callback runs.
                np.copyto(out, 0)

    sw = x.shape[:-2] + x.shape[-1:]
    sv = x.shape
    w, v = jt.numpy_code(
        [sw, sv],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return w, v
Ejemplo n.º 13
0
def slogdet(x):
    r"""
    Compute the sign and the log of the absolute determinant of x with
    ``np.linalg.slogdet`` via ``jt.numpy_code``.
    :param x (...,M,M):
    :return: sign, logdet.
    sign holds the sign of the determinant; its values can be -1, 0, 1 (real
    numbers only for now). 0 means det is 0 and logdet is -inf.
    Both outputs are declared with x's leading (batch) shape, or with shape
    (1,) when x has no batch dimensions.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        sign, m_a = data["outputs"]
        sign_, t_a = np.linalg.slogdet(a)
        np.copyto(m_a, t_a)
        np.copyto(sign, sign_)

    def backward_code(np, data):
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        if out_index == 0:
            # No gradient is propagated through the sign output.
            np.copyto(out, 0)
        elif out_index == 1:
            # logdet path: dout (with two unit axes appended) times x^-T.
            t = np.reshape(dout, np.shape(dout) + (1, 1))
            t = t * T(np.linalg.inv(inp))
            np.copyto(out, t)

    # Build the output shape as a plain list so append works whatever sequence
    # type x.shape happens to be.
    det_s = list(x.shape[:-2])
    if not det_s:
        det_s.append(1)
    sign, mx = jt.numpy_code(
        [det_s, det_s],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return sign, mx
Ejemplo n.º 14
0
def cholesky(x):
    r"""
    Compute the Cholesky decomposition of x:
    x = LL^T
    x must be a Hermitian, positive-definite matrix. L is a lower-triangular matrix.
    :param x (...,M,M):
    :return: L (...,M,M).
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        L = data["outputs"][0]
        # Compute with NumPy, then copy into the preallocated output buffer.
        np.copyto(L, np.linalg.cholesky(a))

    def backward_code(np, data):
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        def solve_trans(a, b):
            # Solve a^T y = b for y.
            return np.linalg.solve(T(a), b)

        def phi(m):
            # Lower triangle of m with its diagonal halved.
            return np.tril(m) / (1. + np.eye(m.shape[-1]))

        def conjugate_solve(low, m):
            # Evaluates low^-T m low^-1 by means of two transposed solves.
            return solve_trans(low, T(solve_trans(low, T(m))))

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        # The einsum forms f_out^T dout over the trailing matrix axes.
        s = conjugate_solve(f_out,
                            phi(np.einsum('...ki,...kj->...ij', f_out, dout)))
        # Symmetrize before writing the gradient of the input.
        s = (s + T(s)) / 2.
        np.copyto(out, s)

    lL = jt.numpy_code(
        [x.shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return lL[0]
Ejemplo n.º 15
0
 def check():
     """Run the two-input, two-output op and verify values and gradients.

     The outputs must equal ``a + b`` and ``a - b``; their gradients w.r.t.
     ``a`` and ``b`` must be (1, 1) and (1, -1) respectively.
     """
     a = jt.random((5, 1))
     b = jt.random((5, 1))
     outputs = jt.numpy_code(
         [a.shape, a.shape],
         [a.dtype, a.dtype],
         [a, b],
         forward_code,
         [backward_code1, backward_code2],
     )
     c, d = outputs
     assert numpy.allclose(c.data, (a + b).data)
     assert numpy.allclose(d.data, (a - b).data)
     grads_of_c = jt.grad(c, [a, b])
     grads_of_d = jt.grad(d, [a, b])
     dca, dcb = grads_of_c
     dda, ddb = grads_of_d
     ones = numpy.ones(a.shape)
     minus_ones = ones * -1.0
     for got, want in ((dca, ones), (dcb, ones), (dda, ones),
                       (ddb, minus_ones)):
         assert numpy.allclose(got.data, want)
Ejemplo n.º 16
0
    def test_memory_leak(self):
        """Build and differentiate a two-output op one million times.

        Each iteration checks the forward values (a + b, a - b) and the four
        gradients of the outputs with respect to the inputs.
        """
        def forward_code(np, data):
            lhs, rhs = data["inputs"]
            total, diff = data["outputs"]
            np.add(lhs, rhs, out=total)
            np.subtract(lhs, rhs, out=diff)

        def backward_code1(np, data):
            # First input: dout is passed through unchanged.
            upstream = data["dout"]
            target = data["outputs"][0]
            np.copyto(target, upstream)

        def backward_code2(np, data):
            # Second input: unchanged for output 0, negated otherwise.
            upstream = data["dout"]
            which = data["out_index"]
            target = data["outputs"][0]
            if which != 0:
                np.negative(upstream, target)
            else:
                np.copyto(target, upstream)

        for _ in range(1000000):
            a = jt.random((10000, 1))
            b = jt.random((10000, 1))
            shapes = [a.shape, a.shape]
            dtypes = [a.dtype, a.dtype]
            c, d = jt.numpy_code(shapes, dtypes, [a, b], forward_code,
                                 [backward_code1, backward_code2])
            assert numpy.allclose(c.data, (a + b).data)
            assert numpy.allclose(d.data, (a - b).data)
            dca, dcb = jt.grad(c, [a, b])
            dda, ddb = jt.grad(d, [a, b])
            ones = numpy.ones(a.shape)
            minus_ones = ones * -1.0
            for got, want in ((dca, ones), (dcb, ones), (dda, ones),
                              (ddb, minus_ones)):
                assert numpy.allclose(got.data, want)
Ejemplo n.º 17
0
def pinv(x):
    r"""
    Compute the pseudo-inverse of x with ``np.linalg.pinv`` via ``jt.numpy_code``.
    :param x (...,M,N)
    :return: pinv of x (...,N,M)
    """
    def forward_code(np, data):
        src = data["inputs"][0]
        dst = data["outputs"][0]
        np.copyto(dst, np.linalg.pinv(src))

    def backward_code(np, data):
        def transpose(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        def mm(lhs, rhs):
            # Batched matrix product over the two trailing axes.
            return np.einsum('...ij,...jk->...ik', lhs, rhs)

        upstream = data["dout"]
        grad_in = data["outputs"][0]
        inp = data["inputs"][0]
        forward_results = data["f_outputs"]
        p = forward_results[0]
        # Three terms are summed in the pseudo-inverse's layout; the sum is
        # transposed back to the layout of the input before being written.
        first = -mm(mm(p, transpose(upstream)), p)
        second = mm(mm(mm(p, transpose(p)), upstream),
                    np.eye(inp.shape[-2]) - mm(inp, p))
        third = mm(mm(mm(np.eye(p.shape[-2]) - mm(p, inp), upstream),
                      transpose(p)), p)
        np.copyto(grad_in, transpose(first + second + third))

    # Output shape is x's shape with the two trailing axes swapped.
    batch_dims = list(x.shape[:-2])
    out_shape = batch_dims + [x.shape[-1]] + [x.shape[-2]]
    results = jt.numpy_code(
        [out_shape],
        [x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return results[0]
Ejemplo n.º 18
0
def solve(a, b):
    r"""
    Solve the linear matrix equation a x = b with ``np.linalg.solve`` via
    ``jt.numpy_code``. a must not be singular.
    :param a: (...,M,M)
    :param b: (...,M)
    :return: the solution x of a x = b, declared with the shape and dtype of b.
    """
    def forward_code(np, data):
        lhs, rhs = data["inputs"]
        res = data["outputs"][0]
        np.copyto(res, np.linalg.solve(lhs, rhs))

    def backward_code1(np, data):
        # Gradient w.r.t. a: -(a^-T dout) x^T.
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        f_out = data["f_outputs"][0]
        inp = data["inputs"][0]

        def updim(m):
            # Promote a vector operand to a column matrix so the batched
            # product below becomes an outer product. The rank is compared
            # against the NumPy input rather than the outer Jittor variable.
            return m if m.ndim == inp.ndim else m[..., None]

        t = -np.einsum('...ij,...jk->...ik',
                       updim(np.linalg.solve(T(inp), dout)),
                       T(updim(f_out)))
        np.copyto(out, t)

    def backward_code2(np, data):
        # Gradient w.r.t. b: a^-T dout. This used to be written as zeros,
        # which silently dropped every gradient flowing into b.
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        np.copyto(out, np.linalg.solve(np.swapaxes(inp, -1, -2), dout))

    l_ans = jt.numpy_code(
        [b.shape],
        [b.dtype],
        [a, b],
        forward_code,
        [backward_code1, backward_code2],
    )
    return l_ans[0]
Ejemplo n.º 19
0
    def test(self):
        """Check a doubling op built with ``jt.numpy_code``.

        The output must equal ``a + a`` and its gradient w.r.t. ``a`` must be 2.
        """
        def forward_code(np, data):
            src = data["inputs"][0]
            dst = data["outputs"][0]
            np.add(src, src, out=dst)

        def backward_code(np, data):
            upstream = data["dout"]
            target = data["outputs"][0]
            np.copyto(target, upstream * 2.0)

        a = jt.random((5, 1))
        doubled = jt.numpy_code(a.shape, a.dtype, [a], forward_code,
                                [backward_code])
        assert np.allclose(doubled.data, (a + a).data)
        grad_a = jt.grad(doubled, a)
        expected = np.ones(a.shape) * 2.0
        assert np.allclose(grad_a.data, expected)
Ejemplo n.º 20
0
def slogdet(x):
    r"""
    Compute the sign and the log of the absolute determinant of x with
    ``np.linalg.slogdet`` via ``jt.numpy_code``.

    Both outputs are declared with x's leading (batch) shape, or with shape
    (1,) when x has no batch dimensions.
    :param x: input handed to ``np.linalg.slogdet``; presumably (...,M,M).
    :return: sign, logdet.
    """
    def forward_code(np, data):
        a = data["inputs"][0]
        sign, m_a = data["outputs"]
        sign_, t_a = np.linalg.slogdet(a)
        np.copyto(m_a, t_a)
        np.copyto(sign, sign_)

    def backward_code(np, data):
        def T(m):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(m, -1, -2)

        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        out_index = data["out_index"]
        if out_index == 0:
            # No gradient is propagated through the sign output.
            np.copyto(out, 0)
        elif out_index == 1:
            # logdet path: dout (with two unit axes appended) times x^-T.
            t = np.reshape(dout, np.shape(dout) + (1, 1))
            t = t * T(np.linalg.inv(inp))
            np.copyto(out, t)

    # Build the output shape as a plain list so append works whatever sequence
    # type x.shape happens to be.
    det_s = list(x.shape[:-2])
    if not det_s:
        det_s.append(1)
    sign, mx = jt.numpy_code(
        [det_s, det_s],
        [x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return sign, mx
Ejemplo n.º 21
0
def svd(x):
    r'''
    Compute the Singular Value Decomposition of x. It follows the formula:
    x = usv*
    Only the reduced form (full_matrices == False) is supported for now, which means:
    x's shape (...,M,N)
    u's shape (...,M,K)
    s's shape (...,K)
    v's shape (...,K,N)
    where K is min(M,N).
    :param x:
    :return:u,s,v.
    '''
    def forward_code(np, data):
        a = data["inputs"][0]
        u, s, v = data["outputs"]
        #TODO:remove copyto
        tu, ts, tv = np.linalg.svd(a, full_matrices=0)
        np.copyto(u, tu)
        np.copyto(s, ts)
        np.copyto(v, tv)

    def backward_code(np, data):
        def T(x):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(x, -1, -2)

        # Batched matrix product over the two trailing axes.
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        # Selects which forward output dout belongs to: 0 is u, 1 is s, 2 is v.
        out_index = data["out_index"]
        u, s, v = data["f_outputs"]
        # The stored third output has shape (...,K,N); work with its transpose.
        v = T(v)
        m, n = inp.shape[-2:]
        k = np.min((m, n))
        # k-by-k identity with leading unit axes so it broadcasts over the batch.
        i = np.reshape(
            np.eye(k),
            np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k))))
        if out_index == 0:
            # Adding the identity keeps the diagonal of the denominator non-zero.
            f = 1 / (s[..., np.newaxis, :]**2 - s[..., :, np.newaxis]**2 + i)
            gu = dout
            utgu = _dot(T(u), gu)
            t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :]
            t = _dot(_dot(u, t), T(v))
            if m > n:
                # Extra term built from (I - u u^*), used only for tall inputs.
                i_minus_uut = (np.reshape(
                    np.eye(m),
                    np.concatenate((np.ones(inp.ndim - 2, dtype=int),
                                    (m, m)))) - _dot(u, np.conj(T(u))))
                t = t + T(
                    _dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut))
            np.copyto(out, t)
        elif out_index == 1:
            gs = dout
            # Place the singular-value gradient on a diagonal: u diag(gs) v^T.
            t = i * gs[..., :, np.newaxis]
            t = _dot(_dot(u, t), T(v))
            np.copyto(out, t)
        elif out_index == 2:
            # Adding the identity keeps the diagonal of the denominator non-zero.
            f = 1 / (s[..., np.newaxis, :]**2 - s[..., :, np.newaxis]**2 + i)
            gv = dout
            vtgv = _dot(T(v), gv)
            t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv)))
            t = _dot(_dot(u, t), T(v))
            if m < n:
                # Extra term built from (I - v v^*), used only for wide inputs.
                i_minus_vvt = (np.reshape(
                    np.eye(n),
                    np.concatenate((np.ones(inp.ndim - 2, dtype=int),
                                    (n, n)))) - _dot(v, np.conj(T(v))))
                t = t + T(
                    _dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
            np.copyto(out, t)

    m, n = x.shape[-2:]
    k = min(m, n)
    # s1: shape of u (...,M,K); s2: shape of v (...,K,N); s3: shape of s (...,K).
    s1 = list(x.shape)
    s1[-1] = k
    s2 = list(x.shape)
    s2[-2] = k
    s3 = list(x.shape)[:-2]
    s3.append(k)
    u, s, v = jt.numpy_code(
        [s1, s3, s2],
        [x.dtype, x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return u, s, v
Ejemplo n.º 22
0
def svd(x):
    r'''
    Compute the reduced Singular Value Decomposition of x with
    ``np.linalg.svd(..., full_matrices=0)`` via ``jt.numpy_code``.
    With M, N the two trailing sizes of x and K = min(M, N), the outputs are
    declared as u (...,M,K), s (...,K) and v (...,K,N), all with x's dtype.
    :param x:
    :return:u,s,v.
    '''
    def forward_code(np, data):
        a = data["inputs"][0]
        u, s, v = data["outputs"]
        #TODO:remove copyto
        tu, ts, tv = np.linalg.svd(a, full_matrices=0)
        np.copyto(u, tu)
        np.copyto(s, ts)
        np.copyto(v, tv)

    def backward_code(np, data):
        def T(x):
            # Swap the two trailing axes (batched matrix transpose).
            return np.swapaxes(x, -1, -2)

        # Batched matrix product over the two trailing axes.
        _dot = partial(np.einsum, '...ij,...jk->...ik')
        dout = data["dout"]
        out = data["outputs"][0]
        inp = data["inputs"][0]
        # Selects which forward output dout belongs to: 0 is u, 1 is s, 2 is v.
        out_index = data["out_index"]
        u, s, v = data["f_outputs"]
        # The stored third output has shape (...,K,N); work with its transpose.
        v = T(v)
        m, n = inp.shape[-2:]
        k = np.min((m, n))
        # k-by-k identity with leading unit axes so it broadcasts over the batch.
        i = np.reshape(
            np.eye(k),
            np.concatenate((np.ones(inp.ndim - 2, dtype=int), (k, k))))
        if out_index == 0:
            # Adding the identity keeps the diagonal of the denominator non-zero.
            f = 1 / (s[..., np.newaxis, :]**2 - s[..., :, np.newaxis]**2 + i)
            gu = dout
            utgu = _dot(T(u), gu)
            t = (f * (utgu - T(utgu))) * s[..., np.newaxis, :]
            t = _dot(_dot(u, t), T(v))
            if m > n:
                # Extra term built from (I - u u^*), used only for tall inputs.
                i_minus_uut = (np.reshape(
                    np.eye(m),
                    np.concatenate((np.ones(inp.ndim - 2, dtype=int),
                                    (m, m)))) - _dot(u, np.conj(T(u))))
                t = t + T(
                    _dot(_dot(v / s[..., np.newaxis, :], T(gu)), i_minus_uut))
            np.copyto(out, t)
        elif out_index == 1:
            gs = dout
            # Place the singular-value gradient on a diagonal: u diag(gs) v^T.
            t = i * gs[..., :, np.newaxis]
            t = _dot(_dot(u, t), T(v))
            np.copyto(out, t)
        elif out_index == 2:
            # Adding the identity keeps the diagonal of the denominator non-zero.
            f = 1 / (s[..., np.newaxis, :]**2 - s[..., :, np.newaxis]**2 + i)
            gv = dout
            vtgv = _dot(T(v), gv)
            t = s[..., :, np.newaxis] * (f * (vtgv - T(vtgv)))
            t = _dot(_dot(u, t), T(v))
            if m < n:
                # Extra term built from (I - v v^*), used only for wide inputs.
                i_minus_vvt = (np.reshape(
                    np.eye(n),
                    np.concatenate((np.ones(inp.ndim - 2, dtype=int),
                                    (n, n)))) - _dot(v, np.conj(T(v))))
                t = t + T(
                    _dot(_dot(u / s[..., np.newaxis, :], T(gv)), i_minus_vvt))
            np.copyto(out, t)

    m, n = x.shape[-2:]
    k = min(m, n)
    # s1: shape of u (...,M,K); s2: shape of v (...,K,N); s3: shape of s (...,K).
    s1 = list(x.shape)
    s1[-1] = k
    s2 = list(x.shape)
    s2[-2] = k
    s3 = list(x.shape)[:-2]
    s3.append(k)
    u, s, v = jt.numpy_code(
        [s1, s3, s2],
        [x.dtype, x.dtype, x.dtype],
        [x],
        forward_code,
        [backward_code],
    )
    return u, s, v