Exemple #1
0
def sigma2_correction(d, order, n_v, n_s, f1_A, f1_B, f_2, h_2, f_ss, Sigma):
    """Build the sigma^2 correction terms from an existing perturbation solution.

    Relies on ``np``, ``sdot``, ``mdot`` and ``solve_sylvester`` being available
    in the enclosing module.

    :param d: mapping holding the solution derivatives; the keys ``'g_a'``,
        ``'g_e'``, ``'g_aa'``, ``'g_ae'``, ``'g_ee'`` and ``'g_ss'`` are read
        as soon as ``order >= 2``.
    :param order: approximation order; nothing is computed below 2.
    :param n_v: size used for the identity matrix and to slice ``f_2``.
    :param n_s: size used for the zero block and the selector vector.
    :param f1_A, f1_B: first-order blocks entering the linear systems.
    :param f_2, h_2, f_ss: higher-order tensors contracted with ``mdot``.
    :param Sigma: matrix contracted against the shock curvature term
        (presumably the shock covariance -- confirm with the caller).
    :return: ``{}`` when ``order < 2``; a dict with ``'sigma2'`` when
        ``order >= 2``; additionally ``'dsigma2'``, ``'D'`` and ``'f_ss'``
        when ``order == 3``.
    """
    out = {}
    if not order >= 2:
        return out

    # All five entries are looked up eagerly, even though g_ae is not used.
    g_a, g_e, g_aa, _g_ae, g_ee = (
        d[key] for key in ('g_a', 'g_e', 'g_aa', 'g_ae', 'g_ee'))

    identity = np.eye(n_v, n_v)
    lhs_inverse = np.linalg.inv(sdot(f1_A, g_a + identity) + f1_B)

    # Vector that is zero everywhere except for a one in its last slot.
    selector = np.zeros((n_v * 3 + n_s + 1, ))
    selector[-1] = 1

    shock_curvature = mdot(f_2[:, :n_v, :n_v], [g_e, g_e]) + sdot(f1_A, g_ee)
    forcing = np.tensordot(shock_curvature, Sigma)
    forcing = forcing + mdot(h_2, [selector, selector])
    sigma2 = sdot(lhs_inverse, -forcing) / 2

    # Gap between the recomputed term and the one stored in the solution.
    gap = sigma2 - d['g_ss']
    out['sigma2'] = sigma2

    if order == 3:
        lhs = sdot(f1_A, g_a) + f1_B
        # Result is not used; kept so that a singular matrix still raises here.
        np.linalg.inv(lhs)
        stacked = np.row_stack([
            np.dot(g_a, g_a),
            g_a,
            identity,
            np.zeros((n_s, n_v)),
        ])
        D = mdot(f_ss, [stacked]) + sdot(f1_A, mdot(g_aa, [g_a, gap]))
        out['dsigma2'] = solve_sylvester(lhs, f1_A, g_a, D)
        out['D'] = D
        out['f_ss'] = f_ss
    return out
def sigma2_correction(d, order, n_v, n_s, f1_A, f1_B, f_2, h_2, f_ss, Sigma):
    """Compute the sigma^2 correction terms from an existing perturbation solution.

    Relies on ``np``, ``sdot``, ``mdot`` and ``solve_sylvester`` being available
    in the enclosing module.

    :param d: mapping holding the solution derivatives; the keys ``'g_a'``,
        ``'g_e'``, ``'g_aa'``, ``'g_ee'`` and ``'g_ss'`` are read when
        ``order >= 2``.
    :param order: approximation order; nothing is computed below 2.
    :param n_v: size used for the identity matrix and to slice ``f_2``.
    :param n_s: size used for the zero block and the selector vector.
    :param f1_A, f1_B: first-order blocks entering the linear systems.
    :param f_2, h_2, f_ss: higher-order tensors contracted with ``mdot``.
    :param Sigma: matrix contracted against the shock curvature term
        (presumably the shock covariance -- confirm with the caller).
    :return: ``{}`` when ``order < 2``; a dict with ``'sigma2'`` when
        ``order >= 2``; additionally ``'dsigma2'``, ``'D'`` and ``'f_ss'``
        when ``order == 3``.
    """
    resp = dict()

    if order >= 2:
        g_a = d['g_a']
        g_e = d['g_e']
        g_aa = d['g_aa']
        g_ee = d['g_ee']
        I = np.eye(n_v, n_v)
        A_inv = np.linalg.inv(sdot(f1_A, g_a + I) + f1_B)

        # Vector that is zero everywhere except for a one in its last slot.
        V_s = np.zeros((n_v * 3 + n_s + 1,))
        V_s[-1] = 1

        K_ss = mdot(f_2[:, :n_v, :n_v], [g_e, g_e]) + sdot(f1_A, g_ee)
        rhs = np.tensordot(K_ss, Sigma) + mdot(h_2, [V_s, V_s])
        sigma2 = sdot(A_inv, -rhs) / 2

        # Gap between the recomputed term and the one stored in the solution.
        ds2 = sigma2 - d['g_ss']
        resp['sigma2'] = sigma2

        # The third-order step reuses I, g_a, g_aa and ds2 from above, hence
        # the nesting (order == 3 always implies order >= 2).
        if order == 3:
            A = sdot(f1_A, g_a) + f1_B
            B = f1_A
            C = g_a
            V_x = np.vstack([
                np.dot(g_a, g_a),
                g_a,
                I,
                np.zeros((n_s, n_v)),
            ])

            D = mdot(f_ss, [V_x]) + sdot(f1_A, mdot(g_aa, [g_a, ds2]))
            # D is exposed to the caller through resp['D'] instead of being
            # printed to stdout.
            dsigma2 = solve_sylvester(A, B, C, D)
            resp['dsigma2'] = dsigma2
            resp['D'] = D
            resp['f_ss'] = f_ss
    return resp
def state_perturb(f_fun, g_fun, sigma, sigma2_correction=None, verbose=True):
    """Computes a Taylor approximation of decision rules, given the supplied derivatives.

    The original system is assumed to be in the form:

    .. math::

        E_t f(s_t,x_t,s_{t+1},x_{t+1})

        s_t = g(s_{t-1},x_{t-1}, \\lambda \\epsilon_t)

    where :math:`\\lambda` is a scalar scaling down the risk.  The solution is a function :math:`\\varphi` such that:

    .. math::

        x_t = \\varphi ( s_t, \\sigma )

    The user supplies a list of derivatives of f and g.

    :param f_fun: list of derivatives of f [order0, order1, order2, ...]; its
        length minus one is the approximation order
    :param g_fun: list of derivatives of g [order0, order1, order2, ...]
    :param sigma: covariance matrix of :math:`\\epsilon_t`
    :param sigma2_correction: (optional) first and second derivatives of g w.r.t. sigma if :math:`g` explicitly depends
        on :math:`\\sigma`; element 0 is used at order 2, element 1 at order 3
    :param verbose: accepted for interface compatibility; not read by this function
    :return: ``[X_s]`` at order 1, ``[[X_s, X_ss], [X_tt]]`` at order 2 and
        ``[[X_s, X_ss, X_sss], [X_tt, X_stt]]`` at order 3 (``None`` is returned
        implicitly for any higher order)
    :raises GeneralizedEigenvaluesError: if a pair of diagonal entries of the
        generalized Schur factors are both smaller than ``1e-10`` in absolute value
    :raises GeneralizedEigenvaluesSelectionError: if the eigenvalues around the
        ``n_s``-th position cannot be separated

    Assuming :math:`s_t` ,  :math:`x_t` and :math:`\\epsilon_t` are vectors of size
    :math:`n_s`, :math:`n_x`  and :math:`n_e`  respectively.
    In general the derivative of order :math:`i` of :math:`f`  is a multidimensional array of size :math:`n_x \\times (N, ..., N)`
    with :math:`N=2(n_s+n_x)` repeated :math:`i` times (possibly 0).
    Similarly the derivative of order :math:`i` of :math:`g`  is a multidimensional array of size :math:`n_s \\times (M, ..., M)`
    with :math:`M=n_s+n_x+n_e` repeated :math:`i` times (possibly 0).
    """

    import numpy as np
    from numpy.linalg import solve
    from dolo.algos.dtcscc.perturbations import (
        GeneralizedEigenvaluesError, GeneralizedEigenvaluesDefinition,
        GeneralizedEigenvaluesSelectionError, qzordered)

    approx_order = len(f_fun) - 1  # order of approximation

    [f0, f1] = f_fun[:2]

    [g0, g1] = g_fun[:2]
    n_x = f1.shape[0]  # number of controls
    n_s = f1.shape[1] // 2 - n_x  # number of states
    n_e = g1.shape[1] - n_x - n_s
    n_v = n_s + n_x

    # Column blocks of the Jacobian of f: (s_t, x_t, s_{t+1}, x_{t+1}).
    f_s = f1[:, :n_s]
    f_x = f1[:, n_s:n_s + n_x]
    f_snext = f1[:, n_v:n_v + n_s]
    f_xnext = f1[:, n_v + n_s:]

    # Column blocks of the Jacobian of g: (s, x, e).
    g_s = g1[:, :n_s]
    g_x = g1[:, n_s:n_s + n_x]
    g_e = g1[:, n_v:]

    # Matrix pencil (A, B) handed to the ordered QZ decomposition.
    A = np.vstack([
        np.column_stack([np.eye(n_s), np.zeros((n_s, n_x))]),
        np.column_stack([-f_snext, -f_xnext])
    ])
    B = np.vstack([
        np.column_stack([g_s, g_x]),
        np.column_stack([f_s, f_x])
    ])

    [S, T, Q, Z, eigval] = qzordered(A, B, 1.0 - 1e-8)

    Q = Q.real  # is it really necessary ?
    Z = Z.real

    diag_S = np.diag(S)
    diag_T = np.diag(T)

    tol_geneigvals = 1e-10

    # Explicit test rather than `assert`, so that the check is not stripped
    # when Python runs with -O.
    n_degenerate = sum((abs(diag_S) < tol_geneigvals) *
                       (abs(diag_T) < tol_geneigvals))
    if n_degenerate != 0:
        raise GeneralizedEigenvaluesError(diag_S=diag_S, diag_T=diag_T)

    if max(eigval[:n_s]) >= 1 and min(eigval[n_s:]) < 1:
        # BK conditions are met
        pass
    else:
        eigval_s = sorted(eigval, reverse=True)
        ev_a = eigval_s[n_s - 1]
        ev_b = eigval_s[n_s]
        # NOTE(review): a threshold separating ev_a from ev_b would normally be
        # their midpoint; this is half their difference. Left unchanged because
        # the semantics of qzordered are not visible here -- confirm.
        cutoff = (ev_a - ev_b) / 2
        if not ev_a > ev_b:
            raise GeneralizedEigenvaluesSelectionError(A=A,
                                                       B=B,
                                                       eigval=eigval,
                                                       cutoff=cutoff,
                                                       diag_S=diag_S,
                                                       diag_T=diag_T,
                                                       n_states=n_s)
        import warnings
        if cutoff > 1:
            warnings.warn("Solution is not convergent.")
        else:
            warnings.warn(
                "There are multiple convergent solutions. The one with the smaller eigenvalues was selected."
            )
        [S, T, Q, Z, eigval] = qzordered(A, B, cutoff)

    Z11 = Z[:n_s, :n_s]
    Z21 = Z[n_s:, :n_s]
    S11 = S[:n_s, :n_s]
    T11 = T[:n_s, :n_s]

    # first order solution
    # (C should satisfy
    #  f_s + f_x C + f_snext (g_s + g_x C) + f_xnext C (g_s + g_x C) = 0)
    C = solve(Z11.T, Z21.T).T
    P = np.dot(solve(S11.T, Z11.T).T, solve(Z11.T, T11.T).T)
    Q = g_e

    if approx_order == 1:
        return [C]

    # second order solution
    from dolo.numeric.tensor import sdot, mdot
    from numpy import dot
    from dolo.numeric.matrix_equations import solve_sylvester

    f2 = f_fun[2]
    g2 = g_fun[2]
    g_ss = g2[:, :n_s, :n_s]
    g_sx = g2[:, :n_s, n_s:n_v]
    g_xx = g2[:, n_s:n_v, n_s:n_v]

    X_s = C

    V1_3 = g_s + dot(g_x, X_s)
    V1 = np.vstack([np.eye(n_s), X_s, V1_3, X_s @ V1_3])

    K2 = g_ss + 2 * sdot(g_sx, X_s) + mdot(g_xx, [X_s, X_s])
    # From here on A, B, C, D are the coefficients passed to solve_sylvester
    # (the names are reused; the QZ pencil is no longer needed).
    A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)
    B = f_xnext
    C = V1_3
    D = mdot(f2, [V1, V1]) + sdot(f_snext + dot(f_xnext, X_s), K2)

    X_ss = solve_sylvester(A, B, C, D)

    # second order correction for risk (X_tt)
    g_ee = g2[:, n_v:, n_v:]

    v = np.vstack([g_e, dot(X_s, g_e)])

    K_tt = mdot(f2[:, n_v:, n_v:], [v, v])
    K_tt += sdot(f_snext + dot(f_xnext, X_s), g_ee)
    K_tt += mdot(sdot(f_xnext, X_ss), [g_e, g_e])
    K_tt = np.tensordot(K_tt, sigma, axes=((1, 2), (0, 1)))

    if sigma2_correction is not None:
        K_tt += sdot(f_snext + dot(f_xnext, X_s), sigma2_correction[0])

    L_tt = f_x + dot(f_snext, g_x) + dot(f_xnext,
                                         dot(X_s, g_x) + np.eye(n_x))
    X_tt = solve(L_tt, -K_tt)

    if approx_order == 2:
        return [[X_s, X_ss], [X_tt]]

    # third order solution

    f3 = f_fun[3]
    g3 = g_fun[3]
    g_sss = g3[:, :n_s, :n_s, :n_s]
    g_ssx = g3[:, :n_s, :n_s, n_s:n_v]
    g_sxx = g3[:, :n_s, n_s:n_v, n_s:n_v]
    g_xxx = g3[:, n_s:n_v, n_s:n_v, n_s:n_v]

    V2_3 = K2 + sdot(g_x, X_ss)
    V2 = np.vstack([
        np.zeros((n_s, n_s, n_s)), X_ss, V2_3,
        dot(X_s, V2_3) + mdot(X_ss, [V1_3, V1_3])
    ])

    K3 = g_sss + 3 * sdot(g_ssx, X_s) + 3 * mdot(g_sxx, [X_s, X_s]) + 2 * sdot(
        g_sx, X_ss)
    K3 += 3 * mdot(g_xx, [X_ss, X_s]) + mdot(g_xxx, [X_s, X_s, X_s])
    L3 = 3 * mdot(X_ss, [V1_3, V2_3])

    # A, B and C are the same Sylvester coefficients as at second order.
    D = mdot(f3, [V1, V1, V1]) + 3 * mdot(f2, [V2, V1]) + sdot(
        f_snext + dot(f_xnext, X_s), K3)
    D += sdot(f_xnext, L3)

    X_sss = solve_sylvester(A, B, C, D)

    # now doing sigma correction with sigma replaced by l in the subscripts

    g_se = g2[:, :n_s, n_v:]
    g_xe = g2[:, n_s:n_v, n_v:]

    g_see = g3[:, :n_s, n_v:, n_v:]
    g_xee = g3[:, n_s:n_v, n_v:, n_v:]

    W_l = np.vstack([g_e, dot(X_s, g_e)])

    I_e = np.eye(n_e)

    V_sl = g_se + mdot(g_xe, [X_s, np.eye(n_e)])

    W_sl = np.vstack([V_sl, mdot(X_ss, [V1_3, g_e]) + sdot(X_s, V_sl)])

    K_ee = mdot(f3[:, :, n_v:, n_v:], [V1, W_l, W_l])
    K_ee += 2 * mdot(f2[:, n_v:, n_v:], [W_sl, W_l])

    # stochastic part of W_ll

    SW_ll = np.vstack([g_ee, mdot(X_ss, [g_e, g_e]) + sdot(X_s, g_ee)])

    DW_ll = np.concatenate(
        [X_tt, dot(g_x, X_tt),
         dot(X_s, sdot(g_x, X_tt)) + X_tt])

    K_ee += mdot(f2[:, :, n_v:], [V1, SW_ll])

    K_ = np.tensordot(K_ee, sigma, axes=((2, 3), (0, 1)))

    K_ += mdot(f2[:, :, n_s:], [V1, DW_ll])

    def E(vec):
        # Contract the last two axes of `vec` with the covariance matrix.
        n = len(vec.shape)
        return np.tensordot(vec, sigma, axes=((n - 2, n - 1), (0, 1)))

    L = sdot(g_sx, X_tt) + mdot(g_xx, [X_s, X_tt])

    L += E(g_see + mdot(g_xee, [X_s, I_e, I_e]))

    M = E(mdot(X_sss, [V1_3, g_e, g_e]) + 2 * mdot(X_ss, [V_sl, g_e]))
    M += mdot(X_ss, [V1_3, E(g_ee) + sdot(g_x, X_tt)])

    A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)  # same as before
    B = f_xnext  # same
    C = V1_3  # same
    D = K_ + dot(f_snext + dot(f_xnext, X_s), L) + dot(f_xnext, M)

    if sigma2_correction is not None:
        g_sl = sigma2_correction[1][:, :n_s]
        g_xl = sigma2_correction[1][:, n_s:(n_s + n_x)]
        D += dot(f_snext + dot(f_xnext, X_s),
                 g_sl + dot(g_xl, X_s))  # constant

    X_stt = solve_sylvester(A, B, C, D)

    if approx_order == 3:
        return [[X_s, X_ss, X_sss], [X_tt, X_stt]]
def state_perturb(f_fun, g_fun, sigma, sigma2_correction=None, verbose=True, eigmax=1.00000):
    """Computes a Taylor approximation of decision rules, given the supplied derivatives.

    The original system is assumed to be in the form:

    .. math::

        E_t f(s_t,x_t,s_{t+1},x_{t+1})

        s_t = g(s_{t-1},x_{t-1}, \\lambda \\epsilon_t)

    where :math:`\\lambda` is a scalar scaling down the risk.  The solution is a function :math:`\\varphi` such that:

    .. math::

        x_t = \\varphi ( s_t, \\sigma )

    The user supplies a list of derivatives of f and g.

    :param f_fun: list of derivatives of f [order0, order1, order2, ...]; its
        length minus one is the approximation order
    :param g_fun: list of derivatives of g [order0, order1, order2, ...]
    :param sigma: covariance matrix of :math:`\\epsilon_t`
    :param sigma2_correction: (optional) first and second derivatives of g w.r.t. sigma if :math:`g` explicitly depends
        on :math:`\\sigma`; element 0 is used at order 2, element 1 at order 3
    :param verbose: if true, print how many eigenvalues exceed ``eigmax``
    :param eigmax: threshold above which an eigenvalue is counted when
        checking the Blanchard-Kahn conditions
    :return: ``[X_s]`` at order 1, ``[[X_s, X_ss], [X_tt]]`` at order 2 and
        ``[[X_s, X_ss, X_sss], [X_tt, X_stt]]`` at order 3 (``None`` is returned
        implicitly for any higher order)
    :raises Exception: if the number of eigenvalues above ``eigmax`` differs
        from the number of controls

    Assuming :math:`s_t` ,  :math:`x_t` and :math:`\\epsilon_t` are vectors of size
    :math:`n_s`, :math:`n_x`  and :math:`n_e`  respectively.
    In general the derivative of order :math:`i` of :math:`f`  is a multidimensional array of size :math:`n_x \\times (N, ..., N)`
    with :math:`N=2(n_s+n_x)` repeated :math:`i` times (possibly 0).
    Similarly the derivative of order :math:`i` of :math:`g`  is a multidimensional array of size :math:`n_s \\times (M, ..., M)`
    with :math:`M=n_s+n_x+n_e` repeated :math:`i` times (possibly 0).
    """

    import numpy as np
    from numpy.linalg import solve

    approx_order = len(f_fun) - 1  # order of approximation

    [f0, f1] = f_fun[:2]

    [g0, g1] = g_fun[:2]
    n_x = f1.shape[0]             # number of controls
    # Floor division: n_s is used as a slice bound and must stay an integer
    # (true division would yield a float on Python 3).
    n_s = f1.shape[1] // 2 - n_x  # number of states
    n_e = g1.shape[1] - n_x - n_s
    n_v = n_s + n_x

    # Column blocks of the Jacobian of f: (s_t, x_t, s_{t+1}, x_{t+1}).
    f_s = f1[:, :n_s]
    f_x = f1[:, n_s:n_s + n_x]
    f_snext = f1[:, n_v:n_v + n_s]
    f_xnext = f1[:, n_v + n_s:]

    # Column blocks of the Jacobian of g: (s, x, e).
    g_s = g1[:, :n_s]
    g_x = g1[:, n_s:n_s + n_x]
    g_e = g1[:, n_v:]

    # Matrix pencil (A, B) handed to the ordered QZ decomposition.
    A = np.vstack([
        np.column_stack([np.eye(n_s), np.zeros((n_s, n_x))]),
        np.column_stack([-f_snext, -f_xnext])
    ])
    B = np.vstack([
        np.column_stack([g_s, g_x]),
        np.column_stack([f_s, f_x])
    ])

    from dolo.numeric.extern.qz import qzordered
    [S, T, Q, Z, eigval] = qzordered(A, B, n_s)

    # Check Blanchard-Kahn conditions
    n_big_one = sum(eigval > eigmax)
    n_expected = n_x
    if verbose:
        print("There are {} eigenvalues greater than 1. Expected: {}.".format(n_big_one, n_x))

    if n_big_one != n_expected:
        raise Exception("There should be exactly {} eigenvalues greater than one. Not {}.".format(n_x, n_big_one))

    Q = Q.real  # is it really necessary ?
    Z = Z.real

    Z11 = Z[:n_s, :n_s]
    Z12 = Z[:n_s, n_s:]
    Z21 = Z[n_s:, :n_s]
    Z22 = Z[n_s:, n_s:]
    S11 = S[:n_s, :n_s]
    T11 = T[:n_s, :n_s]

    # first order solution
    # (C should satisfy
    #  f_s + f_x C + f_snext (g_s + g_x C) + f_xnext C (g_s + g_x C) = 0)
    C = solve(Z11.T, Z21.T).T
    P = np.dot(solve(S11.T, Z11.T).T, solve(Z11.T, T11.T).T)
    Q = g_e

    if approx_order == 1:
        return [C]

    # second order solution
    from dolo.numeric.tensor import sdot, mdot
    from numpy import dot
    from dolo.numeric.matrix_equations import solve_sylvester

    f2 = f_fun[2]
    g2 = g_fun[2]
    g_ss = g2[:, :n_s, :n_s]
    g_sx = g2[:, :n_s, n_s:n_v]
    g_xx = g2[:, n_s:n_v, n_s:n_v]

    X_s = C

    V1_3 = g_s + dot(g_x, X_s)
    V1 = np.vstack([
        np.eye(n_s),
        X_s,
        V1_3,
        dot(X_s, V1_3)
    ])

    K2 = g_ss + 2 * sdot(g_sx, X_s) + mdot(g_xx, [X_s, X_s])
    # From here on A, B, C, D are the coefficients passed to solve_sylvester
    # (the names are reused; the QZ pencil is no longer needed).
    A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)
    B = f_xnext
    C = V1_3
    D = mdot(f2, [V1, V1]) + sdot(f_snext + dot(f_xnext, X_s), K2)

    X_ss = solve_sylvester(A, B, C, D)

    # second order correction for risk (X_tt)
    g_ee = g2[:, n_v:, n_v:]

    v = np.vstack([
        g_e,
        dot(X_s, g_e)
    ])

    K_tt = mdot(f2[:, n_v:, n_v:], [v, v])
    K_tt += sdot(f_snext + dot(f_xnext, X_s), g_ee)
    K_tt += mdot(sdot(f_xnext, X_ss), [g_e, g_e])
    K_tt = np.tensordot(K_tt, sigma, axes=((1, 2), (0, 1)))

    if sigma2_correction is not None:
        K_tt += sdot(f_snext + dot(f_xnext, X_s), sigma2_correction[0])

    L_tt = f_x + dot(f_snext, g_x) + dot(f_xnext, dot(X_s, g_x) + np.eye(n_x))
    X_tt = solve(L_tt, -K_tt)

    if approx_order == 2:
        return [[X_s, X_ss], [X_tt]]

    # third order solution

    f3 = f_fun[3]
    g3 = g_fun[3]
    g_sss = g3[:, :n_s, :n_s, :n_s]
    g_ssx = g3[:, :n_s, :n_s, n_s:n_v]
    g_sxx = g3[:, :n_s, n_s:n_v, n_s:n_v]
    g_xxx = g3[:, n_s:n_v, n_s:n_v, n_s:n_v]

    V2_3 = K2 + sdot(g_x, X_ss)
    V2 = np.vstack([
        np.zeros((n_s, n_s, n_s)),
        X_ss,
        V2_3,
        dot(X_s, V2_3) + mdot(X_ss, [V1_3, V1_3])
    ])

    K3 = g_sss + 3 * sdot(g_ssx, X_s) + 3 * mdot(g_sxx, [X_s, X_s]) + 2 * sdot(g_sx, X_ss)
    K3 += 3 * mdot(g_xx, [X_ss, X_s]) + mdot(g_xxx, [X_s, X_s, X_s])
    L3 = 3 * mdot(X_ss, [V1_3, V2_3])

    # A, B and C are the same Sylvester coefficients as at second order.
    D = mdot(f3, [V1, V1, V1]) + 3 * mdot(f2, [V2, V1]) + sdot(f_snext + dot(f_xnext, X_s), K3)
    D += sdot(f_xnext, L3)

    X_sss = solve_sylvester(A, B, C, D)

    # now doing sigma correction with sigma replaced by l in the subscripts

    g_se = g2[:, :n_s, n_v:]
    g_xe = g2[:, n_s:n_v, n_v:]

    g_see = g3[:, :n_s, n_v:, n_v:]
    g_xee = g3[:, n_s:n_v, n_v:, n_v:]

    W_l = np.vstack([
        g_e,
        dot(X_s, g_e)
    ])

    I_e = np.eye(n_e)

    V_sl = g_se + mdot(g_xe, [X_s, np.eye(n_e)])

    W_sl = np.vstack([
        V_sl,
        mdot(X_ss, [V1_3, g_e]) + sdot(X_s, V_sl)
    ])

    K_ee = mdot(f3[:, :, n_v:, n_v:], [V1, W_l, W_l])
    K_ee += 2 * mdot(f2[:, n_v:, n_v:], [W_sl, W_l])

    # stochastic part of W_ll

    SW_ll = np.vstack([
        g_ee,
        mdot(X_ss, [g_e, g_e]) + sdot(X_s, g_ee)
    ])

    DW_ll = np.concatenate([
        X_tt,
        dot(g_x, X_tt),
        dot(X_s, sdot(g_x, X_tt)) + X_tt
    ])

    K_ee += mdot(f2[:, :, n_v:], [V1, SW_ll])

    K_ = np.tensordot(K_ee, sigma, axes=((2, 3), (0, 1)))

    K_ += mdot(f2[:, :, n_s:], [V1, DW_ll])

    def E(vec):
        # Contract the last two axes of `vec` with the covariance matrix.
        n = len(vec.shape)
        return np.tensordot(vec, sigma, axes=((n - 2, n - 1), (0, 1)))

    L = sdot(g_sx, X_tt) + mdot(g_xx, [X_s, X_tt])

    L += E(g_see + mdot(g_xee, [X_s, I_e, I_e]))

    M = E(mdot(X_sss, [V1_3, g_e, g_e]) + 2 * mdot(X_ss, [V_sl, g_e]))
    M += mdot(X_ss, [V1_3, E(g_ee) + sdot(g_x, X_tt)])

    A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)  # same as before
    B = f_xnext  # same
    C = V1_3  # same
    D = K_ + dot(f_snext + dot(f_xnext, X_s), L) + dot(f_xnext, M)

    if sigma2_correction is not None:
        g_sl = sigma2_correction[1][:, :n_s]
        g_xl = sigma2_correction[1][:, n_s:(n_s + n_x)]
        D += dot(f_snext + dot(f_xnext, X_s), g_sl + dot(g_xl, X_s))   # constant

    X_stt = solve_sylvester(A, B, C, D)

    if approx_order == 3:
        return [[X_s, X_ss, X_sss], [X_tt, X_stt]]
Exemple #5
0
def new_solver_with_p(derivatives, sizes, max_order=2):
    """Compute perturbation derivatives of a decision rule up to ``max_order``.

    Relies on ``np``, ``sdot``, ``mdot``, ``second_order_solver``,
    ``solve_sylvester`` and ``build_V`` being available in the enclosing module.

    The suffixes used for the results are ``a``, ``e`` and ``p``; presumably
    they stand for endogenous variables, shocks and parameters respectively
    (inferred from the names -- confirm with the caller).

    :param derivatives: list ``[f_0, f_1, ..., f_max_order]``; its length must
        be ``max_order + 1``. The columns of ``f_1`` are split into blocks of
        width ``n_v``, ``n_v``, ``n_v``, ``n_s`` and the remainder.
    :param sizes: ``[n_v, n_s, n_p]``.
    :param max_order: 1, 2 or 3.
    :return: dict with ``'ev'``, ``'g_a'``, ``'g_e'``, ``'g_p'``; at order 2
        also ``'g_aa'``, ``'g_ae'``, ``'g_ee'``, ``'g_ap'``, ``'g_ep'``,
        ``'g_pp'``; at order 3 also the ten third-order entries ``'g_aaa'``
        ... ``'g_ppp'``.
    :raises Exception: if ``max_order`` is not 1, 2 or 3.
    """
    if max_order == 1:
        [f_0, f_1] = derivatives
    elif max_order == 2:
        [f_0, f_1, f_2] = derivatives
    elif max_order == 3:
        [f_0, f_1, f_2, f_3] = derivatives
    else:
        # Fail fast instead of hitting a NameError on f_3 (or an IndexError)
        # after the lower orders have already been computed.
        raise Exception(
            'Perturbations not implemented at order {0}'.format(max_order))

    f = derivatives
    [n_v, n_s, n_p] = sizes

    n = n_v

    f1_A = f[1][:, :n]
    f1_B = f[1][:, n:(2 * n)]
    f1_C = f[1][:, (2 * n):(3 * n)]
    f1_D = f[1][:, (3 * n):((3 * n) + n_s)]
    f1_E = f[1][:, ((3 * n) + n_s):]

    ## first order
    [ev, g_x] = second_order_solver(f1_A, f1_B, f1_C)

    mm = np.dot(f1_A, g_x) + f1_B

    g_u = -np.linalg.solve(mm, f1_D)
    g_p = -np.linalg.solve(mm, f1_E)

    d = {'ev': ev, 'g_a': g_x, 'g_e': g_u, 'g_p': g_p}

    if max_order == 1:
        return d

    # we need it for higher order
    V_a = np.concatenate([
        np.dot(g_x, g_x), g_x,
        np.eye(n_v),
        np.zeros((n_s, n_v)),
        np.zeros((n_p, n_v))
    ])
    V_e = np.concatenate([
        np.dot(g_x, g_u), g_u,
        np.zeros((n_v, n_s)),
        np.eye(n_s),
        np.zeros((n_p, n_s))
    ])
    V_p = np.concatenate([
        np.dot(g_x, g_p), g_p,
        np.zeros((n_v, n_p)),
        np.zeros((n_s, n_p)),
        np.eye(n_p)
    ])

    # Translation to the (a, e, p) naming used below
    n_a = n_v
    n_e = n_s
    f_2 = f[2]
    f_d = f1_A
    f_a = f1_B
    g_a = g_x
    g_e = g_u

    # Once for all !
    A = f_a + sdot(f_d, g_a)
    B = f_d
    C = g_a
    A_inv = np.linalg.inv(A)

    def solve_linear(K, L):
        # Solve A*X + const = 0 with const = f_d.L + K.
        return -sdot(A_inv, sdot(f_d, L) + K)

    def stack(L, g):
        # Stacked derivative consumed by the next order.
        return build_V(L + mdot(g_a, [g]), g, (n_a, n_e, n_p))

    #----------Computing order 2

    #--- ('a', 'a'): the only second-order term needing a Sylvester equation
    K_aa = mdot(f_2, [V_a, V_a])
    L_aa = np.zeros((n_v, n_a, n_a))
    g_aa = solve_sylvester(A, B, C, K_aa + sdot(f_d, L_aa))

    #--- remaining pairs: plain linear solves
    L_ae = mdot(g_aa, [g_a, g_e])
    g_ae = solve_linear(mdot(f_2, [V_a, V_e]), L_ae)

    L_ap = mdot(g_aa, [g_a, g_p])
    g_ap = solve_linear(mdot(f_2, [V_a, V_p]), L_ap)

    L_ee = mdot(g_aa, [g_e, g_e])
    g_ee = solve_linear(mdot(f_2, [V_e, V_e]), L_ee)

    L_ep = mdot(g_aa, [g_e, g_p])
    g_ep = solve_linear(mdot(f_2, [V_e, V_p]), L_ep)

    L_pp = mdot(g_aa, [g_p, g_p])
    g_pp = solve_linear(mdot(f_2, [V_p, V_p]), L_pp)

    d.update({
        'g_aa': g_aa,
        'g_ae': g_ae,
        'g_ee': g_ee,
        'g_ap': g_ap,
        'g_ep': g_ep,
        'g_pp': g_pp
    })
    if max_order == 2:
        return d

    #----------Computing order 3

    # Stacked second-order derivatives, only needed on the third-order path.
    V_aa = build_V(L_aa + mdot(g_a, [g_aa]) + mdot(g_aa, [g_a, g_a]), g_aa,
                   (n_a, n_e, n_p))
    V_ae = stack(L_ae, g_ae)
    V_ap = stack(L_ap, g_ap)
    V_ee = stack(L_ee, g_ee)
    V_ep = stack(L_ep, g_ep)
    V_pp = stack(L_pp, g_pp)

    #--- ('a', 'a', 'a'): Sylvester equation with the same A, B, C
    K_aaa = 3 * mdot(f_2, [V_a, V_aa]) + mdot(f_3, [V_a, V_a, V_a])
    L_aaa = 3 * mdot(g_aa, [g_a, g_aa])
    g_aaa = solve_sylvester(A, B, C, K_aaa + sdot(f_d, L_aaa))

    #--- ('a', 'a', 'e')
    K_aae = mdot(f_2, [V_aa, V_e]) + 2 * mdot(f_2, [V_a, V_ae]) + mdot(
        f_3, [V_a, V_a, V_e])
    L_aae = mdot(g_aa, [g_aa, g_e]) + 2 * mdot(g_aa, [g_a, g_ae]) + mdot(
        g_aaa, [g_a, g_a, g_e])
    g_aae = solve_linear(K_aae, L_aae)

    #--- ('a', 'a', 'p')
    K_aap = mdot(f_2, [V_aa, V_p]) + 2 * mdot(f_2, [V_a, V_ap]) + mdot(
        f_3, [V_a, V_a, V_p])
    L_aap = mdot(g_aa, [g_aa, g_p]) + 2 * mdot(g_aa, [g_a, g_ap]) + mdot(
        g_aaa, [g_a, g_a, g_p])
    g_aap = solve_linear(K_aap, L_aap)

    #--- ('a', 'e', 'e')
    K_aee = 2 * mdot(f_2, [V_ae, V_e]) + mdot(f_2, [V_a, V_ee]) + mdot(
        f_3, [V_a, V_e, V_e])
    L_aee = 2 * mdot(g_aa, [g_ae, g_e]) + mdot(g_aa, [g_a, g_ee]) + mdot(
        g_aaa, [g_a, g_e, g_e])
    g_aee = solve_linear(K_aee, L_aee)

    #--- ('a', 'e', 'p'): the (ap, e) term is transposed back to (a, e, p) order
    K_aep = mdot(f_2, [V_ae, V_p]) + mdot(f_2, [V_ap, V_e]).swapaxes(
        2, 3) + mdot(f_2, [V_a, V_ep]) + mdot(f_3, [V_a, V_e, V_p])
    L_aep = mdot(g_aa, [g_ae, g_p]) + mdot(g_aa, [g_ap, g_e]).swapaxes(
        2, 3) + mdot(g_aa, [g_a, g_ep]) + mdot(g_aaa, [g_a, g_e, g_p])
    g_aep = solve_linear(K_aep, L_aep)

    #--- ('a', 'p', 'p')
    K_app = 2 * mdot(f_2, [V_ap, V_p]) + mdot(f_2, [V_a, V_pp]) + mdot(
        f_3, [V_a, V_p, V_p])
    L_app = 2 * mdot(g_aa, [g_ap, g_p]) + mdot(g_aa, [g_a, g_pp]) + mdot(
        g_aaa, [g_a, g_p, g_p])
    g_app = solve_linear(K_app, L_app)

    #--- ('e', 'e', 'e')
    K_eee = 3 * mdot(f_2, [V_e, V_ee]) + mdot(f_3, [V_e, V_e, V_e])
    L_eee = 3 * mdot(g_aa, [g_e, g_ee]) + mdot(g_aaa, [g_e, g_e, g_e])
    g_eee = solve_linear(K_eee, L_eee)

    #--- ('e', 'e', 'p')
    K_eep = mdot(f_2, [V_ee, V_p]) + 2 * mdot(f_2, [V_e, V_ep]) + mdot(
        f_3, [V_e, V_e, V_p])
    L_eep = mdot(g_aa, [g_ee, g_p]) + 2 * mdot(g_aa, [g_e, g_ep]) + mdot(
        g_aaa, [g_e, g_e, g_p])
    g_eep = solve_linear(K_eep, L_eep)

    #--- ('e', 'p', 'p')
    K_epp = 2 * mdot(f_2, [V_ep, V_p]) + mdot(f_2, [V_e, V_pp]) + mdot(
        f_3, [V_e, V_p, V_p])
    L_epp = 2 * mdot(g_aa, [g_ep, g_p]) + mdot(g_aa, [g_e, g_pp]) + mdot(
        g_aaa, [g_e, g_p, g_p])
    g_epp = solve_linear(K_epp, L_epp)

    #--- ('p', 'p', 'p')
    K_ppp = 3 * mdot(f_2, [V_p, V_pp]) + mdot(f_3, [V_p, V_p, V_p])
    L_ppp = 3 * mdot(g_aa, [g_p, g_pp]) + mdot(g_aaa, [g_p, g_p, g_p])
    g_ppp = solve_linear(K_ppp, L_ppp)

    d.update({
        'g_aaa': g_aaa,
        'g_aae': g_aae,
        'g_aee': g_aee,
        'g_eee': g_eee,
        'g_aap': g_aap,
        'g_aep': g_aep,
        'g_eep': g_eep,
        'g_app': g_app,
        'g_epp': g_epp,
        'g_ppp': g_ppp
    })

    return d
Exemple #6
0
def perturb_solver(derivatives,
                   Sigma,
                   max_order=2,
                   derivatives_ss=None,
                   mlab=None):
    """Compute the coefficients of a perturbation solution up to ``max_order``.

    The first-order coefficient ``g_a`` comes from ``second_order_solver``.
    Every higher-order coefficient is then obtained from a linear problem:
    ``solve_sylvester(A, B, C, D)`` for the derivatives taken only with
    respect to ``a``, and a plain linear system in ``A`` for every
    derivative that involves ``e``.

    :param derivatives: sequence ``[f_0, f_1, ..., f_max_order]``; it must
        contain exactly ``max_order + 1`` items. With ``n = f_0.shape[0]``,
        the columns of ``f_1`` are read as four consecutive groups of widths
        ``n``, ``n``, ``n`` and ``s = f_1.shape[1] - 3 * n``.
    :param Sigma: matrix contracted against the two ``e`` axes of the
        second-order terms (presumably the covariance matrix of the shocks;
        to be confirmed against callers).
    :param max_order: order of the approximation: 1, 2 or 3.
    :param derivatives_ss: optional sequence; item 0 is subtracted from the
        right-hand side of the ``g_ss`` equation and item 1 enters the
        ``g_ass`` equation (order 3 only).
    :param mlab: optional object exposing ``gensylv``; when given, it is
        used instead of ``solve_sylvester`` to compute ``g_aa`` and
        ``g_aaa``.
    :return: dict with ``'ev'`` (first output of ``second_order_solver``)
        and the coefficients ``'g_a'``, ``'g_e'``; from order 2 also
        ``'g_aa'``, ``'g_ae'``, ``'g_ee'``, ``'g_ss'``; at order 3 also
        ``'g_aaa'``, ``'g_aae'``, ``'g_aee'``, ``'g_eee'``, ``'g_ass'``,
        ``'g_ess'``.
    :raises Exception: if ``max_order`` is not 1, 2 or 3.
    """

    # Unpacking doubles as a check that the right number of derivatives
    # was supplied for the requested order.
    if max_order == 1:
        [f_0, f_1] = derivatives
    elif max_order == 2:
        [f_0, f_1, f_2] = derivatives
    elif max_order == 3:
        [f_0, f_1, f_2, f_3] = derivatives
    else:
        raise Exception(
            'Perturbations not implemented at order {0}'.format(max_order))

    f = derivatives
    n = f[0].shape[0]  # number of variables
    s = f[1].shape[1] - 3 * n  # width of the last column group of f_1
    [n_v, n_s] = [n, s]

    f1_A = f[1][:, :n]
    f1_B = f[1][:, n:(2 * n)]
    f1_C = f[1][:, (2 * n):(3 * n)]
    f1_D = f[1][:, (3 * n):]

    ## first order
    # g_x is expected to satisfy f1_A.g_x.g_x + f1_B.g_x + f1_C = 0
    # (that is the residual the original code evaluated here).
    [ev, g_x] = second_order_solver(f1_A, f1_B, f1_C)

    mm = np.dot(f1_A, g_x) + f1_B

    g_u = -np.linalg.solve(mm, f1_D)

    if max_order == 1:
        d = {'ev': ev, 'g_a': g_x, 'g_e': g_u}
        return d

    # Stacked first derivatives, needed by every higher-order term.
    V_a = np.concatenate(
        [np.dot(g_x, g_x), g_x,
         np.eye(n_v), np.zeros((s, n))])
    V_e = np.concatenate(
        [np.dot(g_x, g_u), g_u,
         np.zeros((n_v, n_s)),
         np.eye(n_s)])

    # Translation to the naming used by the generated code below
    f_d = f1_A
    f_a = f1_B
    g_a = g_x
    g_e = g_u
    n_a = n_v
    n_e = n_s

    # Matrices shared by all the linear problems solved below.
    A = f_a + sdot(f_d, g_a)
    B = f_d
    C = g_a
    A_inv = np.linalg.inv(A)

    ##################
    # Automatic code #
    ##################

    #----------Computing order 2

    order = 2

    #--- Computing derivatives ('a', 'a')

    K_aa = mdot(f_2, [V_a, V_a])
    L_aa = np.zeros((n_v, n_v, n_v))

    # Sylvester equation with A, B, C defined above
    D = K_aa + sdot(f_d, L_aa)
    if mlab is None:
        g_aa = solve_sylvester(A, B, C, D)
    else:
        n_d = D.ndim - 1
        n_v = C.shape[1]
        DD = D.reshape(n_v, n_v**n_d)
        [err, E] = mlab.gensylv(2, A, B, C, DD, nout=2)
        g_aa = -E.reshape((n_v, n_v, n_v))  # check that - is correct

    if order < max_order:
        Y = L_aa + mdot(g_a, [g_aa]) + mdot(g_aa, [g_a, g_a])
        assert (abs(mdot(g_a, [g_aa]) - sdot(g_a, g_aa)).max() == 0)
        Z = g_aa
        V_aa = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('a', 'e')

    K_ae = mdot(f_2, [V_a, V_e])
    L_ae = mdot(g_aa, [g_a, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ae) + K_ae
    g_ae = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ae + mdot(g_a, [g_ae])
        Z = g_ae
        V_ae = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('e', 'e')

    K_ee = mdot(f_2, [V_e, V_e])
    L_ee = mdot(g_aa, [g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ee) + K_ee
    g_ee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ee + mdot(g_a, [g_ee])
        Z = g_ee
        V_ee = build_V(Y, Z, (n_a, n_e))

    # manual: second-order term in sigma
    I = np.eye(n_v, n_v)
    M_inv = np.linalg.inv(sdot(f1_A, g_a + I) + f1_B)
    K_ss = mdot(f_2[:, :n_v, :n_v], [g_e, g_e]) + sdot(f1_A, g_ee)
    rhs = -np.tensordot(K_ss, Sigma, axes=((1, 2), (0, 1)))
    if derivatives_ss:
        f_ss = derivatives_ss[0]
        rhs -= f_ss
    g_ss = sdot(M_inv, rhs)

    if max_order == 2:
        d = {
            'ev': ev,
            'g_a': g_a,
            'g_e': g_e,
            'g_aa': g_aa,
            'g_ae': g_ae,
            'g_ee': g_ee,
            'g_ss': g_ss
        }
        return d
    # /manual

    #----------Computing order 3

    order = 3

    #--- Computing derivatives ('a', 'a', 'a')
    K_aaa = 3 * mdot(f_2, [V_a, V_aa]) + mdot(f_3, [V_a, V_a, V_a])
    L_aaa = 3 * mdot(g_aa, [g_a, g_aa])

    # Sylvester equation with A, B, C defined above
    D = K_aaa + sdot(f_d, L_aaa)

    if mlab is None:
        g_aaa = solve_sylvester(A, B, C, D)
    # this is much much faster
    else:
        n_d = D.ndim - 1
        n_v = C.shape[1]
        DD = D.reshape(n_v, n_v**n_d)
        [err, E] = mlab.gensylv(3, A, B, C, DD, nout=2)
        # NOTE(review): the order-2 branch negates E while this one does
        # not; one of the two signs is probably wrong -- confirm against
        # the gensylv convention.
        g_aaa = E.reshape((n_v, n_v, n_v, n_v))

    # Only reached if an order above 3 is ever supported.
    if order < max_order:
        Y = L_aaa + mdot(g_a, [g_aaa]) + mdot(g_aaa, [g_a, g_a, g_a])
        Z = g_aaa
        V_aaa = build_V(Y, Z, (n_a, n_e))

    # we transform g_aaa into a symmetric multilinear form by averaging
    # over the six permutations of its last three axes
    g_aaa = (g_aaa + g_aaa.swapaxes(3, 2) + g_aaa.swapaxes(1, 2) +
             g_aaa.swapaxes(1, 2).swapaxes(2, 3) + g_aaa.swapaxes(1, 3) +
             g_aaa.swapaxes(1, 3).swapaxes(2, 3)) / 6

    #--- Computing derivatives ('a', 'a', 'e')

    K_aae = mdot(f_2, [V_aa, V_e]) + 2 * mdot(f_2, [V_a, V_ae]) + mdot(
        f_3, [V_a, V_a, V_e])
    L_aae = mdot(g_aa, [g_aa, g_e]) + 2 * mdot(g_aa, [g_a, g_ae]) + mdot(
        g_aaa, [g_a, g_a, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_aae) + K_aae
    g_aae = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aae + mdot(g_a, [g_aae])
        Z = g_aae
        V_aae = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('a', 'e', 'e')

    K_aee = 2 * mdot(f_2, [V_ae, V_e]) + mdot(f_2, [V_a, V_ee]) + mdot(
        f_3, [V_a, V_e, V_e])
    L_aee = 2 * mdot(g_aa, [g_ae, g_e]) + mdot(g_aa, [g_a, g_ee]) + mdot(
        g_aaa, [g_a, g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_aee) + K_aee
    g_aee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aee + mdot(g_a, [g_aee])
        Z = g_aee
        V_aee = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('e', 'e', 'e')

    K_eee = 3 * mdot(f_2, [V_e, V_ee]) + mdot(f_3, [V_e, V_e, V_e])
    L_eee = 3 * mdot(g_aa, [g_e, g_ee]) + mdot(g_aaa, [g_e, g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_eee) + K_eee
    g_eee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_eee + mdot(g_a, [g_eee])
        Z = g_eee
        V_eee = build_V(Y, Z, (n_a, n_e))

    ####################################
    ## Compute sigma^2 correction term #
    ####################################

    # Quantities shared by the ( a s s ) and ( e s s ) terms. A and A_inv
    # are still the matrices computed once above.
    I_e = np.eye(n_e)

    Y = g_e
    Z = np.zeros((n_a, n_e))
    V_s = build_V(Y, Z, (n_a, n_e))

    Y = sdot(g_a, g_ss) + g_ss + np.tensordot(g_ee, Sigma)
    Z = g_ss
    V_ss = build_V(Y, Z, (n_a, n_e))

    # ( a s s )

    Y = mdot(g_ae, [g_a, I_e])
    Z = np.zeros((n_a, n_a, n_e))
    V_as = build_V(Y, Z, (n_a, n_e))

    K_ass_1 = 2 * mdot(f_2, [V_as, V_s]) + mdot(f_3, [V_a, V_s, V_s])
    K_ass_1 = np.tensordot(K_ass_1, Sigma)

    K_ass_2 = mdot(f_2, [V_a, V_ss])

    K_ass = K_ass_1 + K_ass_2

    L_ass = mdot(g_aa, [g_a, g_ss]) + np.tensordot(
        mdot(g_aee, [g_a, I_e, I_e]), Sigma)

    D = K_ass + sdot(f_d, L_ass)

    if derivatives_ss:
        f_1ss = derivatives_ss[1]
        D += mdot(f_1ss, [V_a]) + sdot(f1_A, mdot(g_aa, [g_a, g_ss]))

    g_ass = solve_sylvester(A, B, C, D)

    # ( e s s )

    Y = mdot(g_ae, [g_e, I_e])
    Z = np.zeros((n_a, n_e, n_e))
    V_es = build_V(Y, Z, (n_a, n_e))

    K_ess_1 = 2 * mdot(f_2, [V_es, V_s]) + mdot(f_3, [V_e, V_s, V_s])
    K_ess_1 = np.tensordot(K_ess_1, Sigma)

    K_ess_2 = mdot(f_2, [V_e, V_ss])

    K_ess = K_ess_1 + K_ess_2

    L_ess = mdot(g_aa, [g_e, g_ss]) + np.tensordot(
        mdot(g_aee, [g_e, I_e, I_e]), Sigma)
    L_ess += mdot(g_ass, [g_e])

    D = K_ess + sdot(f_d, L_ess)

    g_ess = sdot(A_inv, -D)

    # max_order is necessarily 3 at this point.
    d = {
        'ev': ev,
        'g_a': g_a,
        'g_e': g_e,
        'g_aa': g_aa,
        'g_ae': g_ae,
        'g_ee': g_ee,
        'g_aaa': g_aaa,
        'g_aae': g_aae,
        'g_aee': g_aee,
        'g_eee': g_eee,
        'g_ss': g_ss,
        'g_ass': g_ass,
        'g_ess': g_ess
    }
    return d
Exemple #7
0
def new_solver_with_p(derivatives, sizes, max_order=2):
    """Compute perturbation coefficients with respect to ``a``, ``e`` and ``p``.

    The first-order coefficient ``g_a`` comes from ``second_order_solver``;
    ``g_e`` and ``g_p`` follow from a linear solve. Higher-order
    coefficients are obtained with ``solve_sylvester`` for the derivatives
    taken only with respect to ``a`` and with a linear system in ``A`` for
    all the others.

    :param derivatives: sequence ``[f_0, f_1, ..., f_max_order]``; it must
        contain exactly ``max_order + 1`` items.
    :param sizes: ``[n_v, n_s, n_p]``. The columns of ``f_1`` are read as
        three groups of width ``n_v``, one group of width ``n_s`` and a
        final group holding the remaining columns.
    :param max_order: order of the approximation: 1, 2 or 3.
    :return: dict with ``'ev'`` (first output of ``second_order_solver``),
        ``'g_a'``, ``'g_e'``, ``'g_p'``; from order 2 also the six second
        order coefficients (``'g_aa'`` ... ``'g_pp'``); at order 3 also the
        ten third-order coefficients (``'g_aaa'`` ... ``'g_ppp'``).
    :raises Exception: if ``max_order`` is not 1, 2 or 3.
    """

    # Unpacking doubles as a check that the right number of derivatives
    # was supplied for the requested order.
    if max_order == 1:
        [f_0, f_1] = derivatives
    elif max_order == 2:
        [f_0, f_1, f_2] = derivatives
    elif max_order == 3:
        [f_0, f_1, f_2, f_3] = derivatives
    else:
        # Fail early and explicitly rather than with a NameError further
        # down, as perturb_solver does.
        raise Exception(
            'Perturbations not implemented at order {0}'.format(max_order))

    f = derivatives
    [n_v, n_s, n_p] = sizes

    n = n_v

    f1_A = f[1][:, :n]
    f1_B = f[1][:, n:(2 * n)]
    f1_C = f[1][:, (2 * n):(3 * n)]
    f1_D = f[1][:, (3 * n):((3 * n) + n_s)]
    f1_E = f[1][:, ((3 * n) + n_s):]

    ## first order
    [ev, g_x] = second_order_solver(f1_A, f1_B, f1_C)

    mm = np.dot(f1_A, g_x) + f1_B

    g_u = -np.linalg.solve(mm, f1_D)
    g_p = -np.linalg.solve(mm, f1_E)

    d = {
        'ev': ev,
        'g_a': g_x,
        'g_e': g_u,
        'g_p': g_p
    }

    if max_order == 1:
        return d

    # Stacked first derivatives, needed by every higher-order term.
    V_x = np.concatenate([
        np.dot(g_x, g_x), g_x, np.eye(n_v),
        np.zeros((n_s, n_v)), np.zeros((n_p, n_v))
    ])
    V_u = np.concatenate([
        np.dot(g_x, g_u), g_u, np.zeros((n_v, n_s)),
        np.eye(n_s), np.zeros((n_p, n_s))
    ])
    V_p = np.concatenate([
        np.dot(g_x, g_p), g_p, np.zeros((n_v, n_p)),
        np.zeros((n_s, n_p)), np.eye(n_p)
    ])

    # Translation to the naming used by the generated code below
    n_a = n_v
    n_e = n_s
    n_p = g_p.shape[1]

    f_2 = f[2]
    f_d = f1_A
    f_a = f1_B
    V_a = V_x
    V_e = V_u
    g_a = g_x
    g_e = g_u

    # Matrices shared by all the linear problems solved below.
    A = f_a + sdot(f_d, g_a)
    B = f_d
    C = g_a
    A_inv = np.linalg.inv(A)

    #----------Computing order 2

    order = 2

    #--- Computing derivatives ('a', 'a')

    K_aa = mdot(f_2, [V_a, V_a])
    L_aa = np.zeros((n_v, n_a, n_a))

    # Sylvester equation with A, B, C defined above
    D = K_aa + sdot(f_d, L_aa)
    g_aa = solve_sylvester(A, B, C, D)

    if order < max_order:
        Y = L_aa + mdot(g_a, [g_aa]) + mdot(g_aa, [g_a, g_a])
        Z = g_aa
        V_aa = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('a', 'e')

    K_ae = mdot(f_2, [V_a, V_e])
    L_ae = mdot(g_aa, [g_a, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ae) + K_ae
    g_ae = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ae + mdot(g_a, [g_ae])
        Z = g_ae
        V_ae = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('a', 'p')

    K_ap = mdot(f_2, [V_a, V_p])
    L_ap = mdot(g_aa, [g_a, g_p])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ap) + K_ap
    g_ap = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ap + mdot(g_a, [g_ap])
        Z = g_ap
        V_ap = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('e', 'e')

    K_ee = mdot(f_2, [V_e, V_e])
    L_ee = mdot(g_aa, [g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ee) + K_ee
    g_ee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ee + mdot(g_a, [g_ee])
        Z = g_ee
        V_ee = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('e', 'p')

    K_ep = mdot(f_2, [V_e, V_p])
    L_ep = mdot(g_aa, [g_e, g_p])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ep) + K_ep
    g_ep = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ep + mdot(g_a, [g_ep])
        Z = g_ep
        V_ep = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('p', 'p')

    K_pp = mdot(f_2, [V_p, V_p])
    L_pp = mdot(g_aa, [g_p, g_p])

    # We solve A*X + const = 0
    const = sdot(f_d, L_pp) + K_pp
    g_pp = -sdot(A_inv, const)

    if order < max_order:
        Y = L_pp + mdot(g_a, [g_pp])
        Z = g_pp
        V_pp = build_V(Y, Z, (n_a, n_e, n_p))

    d.update({
        'g_aa': g_aa,
        'g_ae': g_ae,
        'g_ee': g_ee,
        'g_ap': g_ap,
        'g_ep': g_ep,
        'g_pp': g_pp
    })
    if max_order == 2:
        return d

    #----------Computing order 3
    # The "order < max_order" branches below are only reached if an order
    # above 3 is ever supported.

    order = 3

    #--- Computing derivatives ('a', 'a', 'a')

    K_aaa = 3 * mdot(f_2, [V_a, V_aa]) + mdot(f_3, [V_a, V_a, V_a])
    L_aaa = 3 * mdot(g_aa, [g_a, g_aa])

    # Sylvester equation with A, B, C defined above
    D = K_aaa + sdot(f_d, L_aaa)
    g_aaa = solve_sylvester(A, B, C, D)

    if order < max_order:
        Y = L_aaa + mdot(g_a, [g_aaa]) + mdot(g_aaa, [g_a, g_a, g_a])
        Z = g_aaa
        V_aaa = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('a', 'a', 'e')

    K_aae = (mdot(f_2, [V_aa, V_e]) + 2 * mdot(f_2, [V_a, V_ae]) +
             mdot(f_3, [V_a, V_a, V_e]))
    L_aae = (mdot(g_aa, [g_aa, g_e]) + 2 * mdot(g_aa, [g_a, g_ae]) +
             mdot(g_aaa, [g_a, g_a, g_e]))

    # We solve A*X + const = 0
    const = sdot(f_d, L_aae) + K_aae
    g_aae = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aae + mdot(g_a, [g_aae])
        Z = g_aae
        V_aae = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('a', 'a', 'p')

    K_aap = (mdot(f_2, [V_aa, V_p]) + 2 * mdot(f_2, [V_a, V_ap]) +
             mdot(f_3, [V_a, V_a, V_p]))
    L_aap = (mdot(g_aa, [g_aa, g_p]) + 2 * mdot(g_aa, [g_a, g_ap]) +
             mdot(g_aaa, [g_a, g_a, g_p]))

    # We solve A*X + const = 0
    const = sdot(f_d, L_aap) + K_aap
    g_aap = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aap + mdot(g_a, [g_aap])
        Z = g_aap
        V_aap = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('a', 'e', 'e')

    K_aee = (2 * mdot(f_2, [V_ae, V_e]) + mdot(f_2, [V_a, V_ee]) +
             mdot(f_3, [V_a, V_e, V_e]))
    L_aee = (2 * mdot(g_aa, [g_ae, g_e]) + mdot(g_aa, [g_a, g_ee]) +
             mdot(g_aaa, [g_a, g_e, g_e]))

    # We solve A*X + const = 0
    const = sdot(f_d, L_aee) + K_aee
    g_aee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aee + mdot(g_a, [g_aee])
        Z = g_aee
        V_aee = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('a', 'e', 'p')

    # The (ap, e) products come out with the e and p axes in the opposite
    # order from the other terms, hence the swapaxes(2, 3).
    K_aep = (mdot(f_2, [V_ae, V_p]) +
             mdot(f_2, [V_ap, V_e]).swapaxes(2, 3) +
             mdot(f_2, [V_a, V_ep]) + mdot(f_3, [V_a, V_e, V_p]))
    L_aep = (mdot(g_aa, [g_ae, g_p]) +
             mdot(g_aa, [g_ap, g_e]).swapaxes(2, 3) +
             mdot(g_aa, [g_a, g_ep]) + mdot(g_aaa, [g_a, g_e, g_p]))

    # We solve A*X + const = 0
    const = sdot(f_d, L_aep) + K_aep
    g_aep = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aep + mdot(g_a, [g_aep])
        Z = g_aep
        V_aep = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('a', 'p', 'p')

    K_app = (2 * mdot(f_2, [V_ap, V_p]) + mdot(f_2, [V_a, V_pp]) +
             mdot(f_3, [V_a, V_p, V_p]))
    L_app = (2 * mdot(g_aa, [g_ap, g_p]) + mdot(g_aa, [g_a, g_pp]) +
             mdot(g_aaa, [g_a, g_p, g_p]))

    # We solve A*X + const = 0
    const = sdot(f_d, L_app) + K_app
    g_app = -sdot(A_inv, const)

    if order < max_order:
        Y = L_app + mdot(g_a, [g_app])
        Z = g_app
        V_app = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('e', 'e', 'e')

    K_eee = 3 * mdot(f_2, [V_e, V_ee]) + mdot(f_3, [V_e, V_e, V_e])
    L_eee = 3 * mdot(g_aa, [g_e, g_ee]) + mdot(g_aaa, [g_e, g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_eee) + K_eee
    g_eee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_eee + mdot(g_a, [g_eee])
        Z = g_eee
        V_eee = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('e', 'e', 'p')

    K_eep = (mdot(f_2, [V_ee, V_p]) + 2 * mdot(f_2, [V_e, V_ep]) +
             mdot(f_3, [V_e, V_e, V_p]))
    L_eep = (mdot(g_aa, [g_ee, g_p]) + 2 * mdot(g_aa, [g_e, g_ep]) +
             mdot(g_aaa, [g_e, g_e, g_p]))

    # We solve A*X + const = 0
    const = sdot(f_d, L_eep) + K_eep
    g_eep = -sdot(A_inv, const)

    if order < max_order:
        Y = L_eep + mdot(g_a, [g_eep])
        Z = g_eep
        V_eep = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('e', 'p', 'p')

    K_epp = (2 * mdot(f_2, [V_ep, V_p]) + mdot(f_2, [V_e, V_pp]) +
             mdot(f_3, [V_e, V_p, V_p]))
    L_epp = (2 * mdot(g_aa, [g_ep, g_p]) + mdot(g_aa, [g_e, g_pp]) +
             mdot(g_aaa, [g_e, g_p, g_p]))

    # We solve A*X + const = 0
    const = sdot(f_d, L_epp) + K_epp
    g_epp = -sdot(A_inv, const)

    if order < max_order:
        Y = L_epp + mdot(g_a, [g_epp])
        Z = g_epp
        V_epp = build_V(Y, Z, (n_a, n_e, n_p))

    #--- Computing derivatives ('p', 'p', 'p')

    K_ppp = 3 * mdot(f_2, [V_p, V_pp]) + mdot(f_3, [V_p, V_p, V_p])
    L_ppp = 3 * mdot(g_aa, [g_p, g_pp]) + mdot(g_aaa, [g_p, g_p, g_p])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ppp) + K_ppp
    g_ppp = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ppp + mdot(g_a, [g_ppp])
        Z = g_ppp
        V_ppp = build_V(Y, Z, (n_a, n_e, n_p))

    d.update({
        'g_aaa': g_aaa,
        'g_aae': g_aae,
        'g_aee': g_aee,
        'g_eee': g_eee,
        'g_aap': g_aap,
        'g_aep': g_aep,
        'g_eep': g_eep,
        'g_app': g_app,
        'g_epp': g_epp,
        'g_ppp': g_ppp
    })

    return d
Exemple #8
0
def perturb_solver(derivatives, Sigma, max_order=2, derivatives_ss=None, mlab=None):
    """Compute the coefficients of a perturbation solution up to ``max_order``.

    The first-order coefficient ``g_a`` comes from ``second_order_solver``.
    Every higher-order coefficient is then obtained from a linear problem:
    ``solve_sylvester(A, B, C, D)`` for the derivatives taken only with
    respect to ``a``, and a plain linear system in ``A`` for every
    derivative that involves ``e``.

    :param derivatives: sequence ``[f_0, f_1, ..., f_max_order]``; it must
        contain exactly ``max_order + 1`` items. With ``n = f_0.shape[0]``,
        the columns of ``f_1`` are read as four consecutive groups of widths
        ``n``, ``n``, ``n`` and ``s = f_1.shape[1] - 3*n``.
    :param Sigma: matrix contracted against the two ``e`` axes of the
        second-order terms (presumably the covariance matrix of the shocks;
        to be confirmed against callers).
    :param max_order: order of the approximation: 1, 2 or 3.
    :param derivatives_ss: optional sequence; item 0 is subtracted from the
        right-hand side of the ``g_ss`` equation and item 1 enters the
        ``g_ass`` equation (order 3 only).
    :param mlab: optional object exposing ``gensylv``; when given, it is
        used instead of ``solve_sylvester`` to compute ``g_aa`` and
        ``g_aaa``.
    :return: dict with ``'ev'`` (first output of ``second_order_solver``)
        and the coefficients ``'g_a'``, ``'g_e'``; from order 2 also
        ``'g_aa'``, ``'g_ae'``, ``'g_ee'``, ``'g_ss'``; at order 3 also
        ``'g_aaa'``, ``'g_aae'``, ``'g_aee'``, ``'g_eee'``, ``'g_ass'``,
        ``'g_ess'``.
    :raises Exception: if ``max_order`` is not 1, 2 or 3.
    """

    # Unpacking doubles as a check that the right number of derivatives
    # was supplied for the requested order.
    if max_order == 1:
        [f_0, f_1] = derivatives
    elif max_order == 2:
        [f_0, f_1, f_2] = derivatives
    elif max_order == 3:
        [f_0, f_1, f_2, f_3] = derivatives
    else:
        raise Exception('Perturbations not implemented at order {0}'.format(max_order))

    f = derivatives
    n = f[0].shape[0]  # number of variables
    s = f[1].shape[1] - 3*n  # width of the last column group of f_1
    [n_v, n_s] = [n, s]

    f1_A = f[1][:, :n]
    f1_B = f[1][:, n:(2*n)]
    f1_C = f[1][:, (2*n):(3*n)]
    f1_D = f[1][:, (3*n):]

    ## first order
    # g_x is expected to satisfy f1_A.g_x.g_x + f1_B.g_x + f1_C = 0
    # (that is the residual the original code evaluated here).
    [ev, g_x] = second_order_solver(f1_A, f1_B, f1_C)

    mm = np.dot(f1_A, g_x) + f1_B

    g_u = -np.linalg.solve(mm, f1_D)

    if max_order == 1:
        d = {'ev': ev, 'g_a': g_x, 'g_e': g_u}
        return d

    # Stacked first derivatives, needed by every higher-order term.
    V_a = np.concatenate([np.dot(g_x, g_x), g_x, np.eye(n_v), np.zeros((s, n))])
    V_e = np.concatenate([np.dot(g_x, g_u), g_u, np.zeros((n_v, n_s)), np.eye(n_s)])

    # Translation to the naming used by the generated code below
    f_d = f1_A
    f_a = f1_B
    g_a = g_x
    g_e = g_u
    n_a = n_v
    n_e = n_s

    # Matrices shared by all the linear problems solved below.
    A = f_a + sdot(f_d, g_a)
    B = f_d
    C = g_a
    A_inv = np.linalg.inv(A)

    ##################
    # Automatic code #
    ##################

    #----------Computing order 2

    order = 2

    #--- Computing derivatives ('a', 'a')

    K_aa = mdot(f_2, [V_a, V_a])
    L_aa = np.zeros((n_v, n_v, n_v))

    # Sylvester equation with A, B, C defined above
    D = K_aa + sdot(f_d, L_aa)
    if mlab is None:
        g_aa = solve_sylvester(A, B, C, D)
    else:
        n_d = D.ndim - 1
        n_v = C.shape[1]
        DD = D.reshape(n_v, n_v**n_d)
        [err, E] = mlab.gensylv(2, A, B, C, DD, nout=2)
        g_aa = -E.reshape((n_v, n_v, n_v))  # check that - is correct

    if order < max_order:
        Y = L_aa + mdot(g_a, [g_aa]) + mdot(g_aa, [g_a, g_a])
        assert(abs(mdot(g_a, [g_aa]) - sdot(g_a, g_aa)).max() == 0)
        Z = g_aa
        V_aa = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('a', 'e')

    K_ae = mdot(f_2, [V_a, V_e])
    L_ae = mdot(g_aa, [g_a, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ae) + K_ae
    g_ae = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ae + mdot(g_a, [g_ae])
        Z = g_ae
        V_ae = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('e', 'e')

    K_ee = mdot(f_2, [V_e, V_e])
    L_ee = mdot(g_aa, [g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_ee) + K_ee
    g_ee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_ee + mdot(g_a, [g_ee])
        Z = g_ee
        V_ee = build_V(Y, Z, (n_a, n_e))

    # manual: second-order term in sigma
    I = np.eye(n_v, n_v)
    M_inv = np.linalg.inv(sdot(f1_A, g_a + I) + f1_B)
    K_ss = mdot(f_2[:, :n_v, :n_v], [g_e, g_e]) + sdot(f1_A, g_ee)
    rhs = -np.tensordot(K_ss, Sigma, axes=((1, 2), (0, 1)))
    if derivatives_ss:
        f_ss = derivatives_ss[0]
        rhs -= f_ss
    g_ss = sdot(M_inv, rhs)

    if max_order == 2:
        d = {
            'ev': ev,
            'g_a': g_a,
            'g_e': g_e,
            'g_aa': g_aa,
            'g_ae': g_ae,
            'g_ee': g_ee,
            'g_ss': g_ss
        }
        return d
    # /manual

    #----------Computing order 3

    order = 3

    #--- Computing derivatives ('a', 'a', 'a')
    K_aaa = 3*mdot(f_2, [V_a, V_aa]) + mdot(f_3, [V_a, V_a, V_a])
    L_aaa = 3*mdot(g_aa, [g_a, g_aa])

    # Sylvester equation with A, B, C defined above
    D = K_aaa + sdot(f_d, L_aaa)

    if mlab is None:
        g_aaa = solve_sylvester(A, B, C, D)
    # this is much much faster
    else:
        n_d = D.ndim - 1
        n_v = C.shape[1]
        DD = D.reshape(n_v, n_v**n_d)
        [err, E] = mlab.gensylv(3, A, B, C, DD, nout=2)
        # NOTE(review): the order-2 branch negates E while this one does
        # not; one of the two signs is probably wrong -- confirm against
        # the gensylv convention.
        g_aaa = E.reshape((n_v, n_v, n_v, n_v))

    # Only reached if an order above 3 is ever supported.
    if order < max_order:
        Y = L_aaa + mdot(g_a, [g_aaa]) + mdot(g_aaa, [g_a, g_a, g_a])
        Z = g_aaa
        V_aaa = build_V(Y, Z, (n_a, n_e))

    # we transform g_aaa into a symmetric multilinear form by averaging
    # over the six permutations of its last three axes
    g_aaa = (g_aaa + g_aaa.swapaxes(3, 2) + g_aaa.swapaxes(1, 2)
             + g_aaa.swapaxes(1, 2).swapaxes(2, 3) + g_aaa.swapaxes(1, 3)
             + g_aaa.swapaxes(1, 3).swapaxes(2, 3)) / 6

    #--- Computing derivatives ('a', 'a', 'e')

    K_aae = mdot(f_2, [V_aa, V_e]) + 2*mdot(f_2, [V_a, V_ae]) + mdot(f_3, [V_a, V_a, V_e])
    L_aae = mdot(g_aa, [g_aa, g_e]) + 2*mdot(g_aa, [g_a, g_ae]) + mdot(g_aaa, [g_a, g_a, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_aae) + K_aae
    g_aae = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aae + mdot(g_a, [g_aae])
        Z = g_aae
        V_aae = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('a', 'e', 'e')

    K_aee = 2*mdot(f_2, [V_ae, V_e]) + mdot(f_2, [V_a, V_ee]) + mdot(f_3, [V_a, V_e, V_e])
    L_aee = 2*mdot(g_aa, [g_ae, g_e]) + mdot(g_aa, [g_a, g_ee]) + mdot(g_aaa, [g_a, g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_aee) + K_aee
    g_aee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_aee + mdot(g_a, [g_aee])
        Z = g_aee
        V_aee = build_V(Y, Z, (n_a, n_e))

    #--- Computing derivatives ('e', 'e', 'e')

    K_eee = 3*mdot(f_2, [V_e, V_ee]) + mdot(f_3, [V_e, V_e, V_e])
    L_eee = 3*mdot(g_aa, [g_e, g_ee]) + mdot(g_aaa, [g_e, g_e, g_e])

    # We solve A*X + const = 0
    const = sdot(f_d, L_eee) + K_eee
    g_eee = -sdot(A_inv, const)

    if order < max_order:
        Y = L_eee + mdot(g_a, [g_eee])
        Z = g_eee
        V_eee = build_V(Y, Z, (n_a, n_e))

    ####################################
    ## Compute sigma^2 correction term #
    ####################################

    # Quantities shared by the ( a s s ) and ( e s s ) terms. A and A_inv
    # are still the matrices computed once above.
    I_e = np.eye(n_e)

    Y = g_e
    Z = np.zeros((n_a, n_e))
    V_s = build_V(Y, Z, (n_a, n_e))

    Y = sdot(g_a, g_ss) + g_ss + np.tensordot(g_ee, Sigma)
    Z = g_ss
    V_ss = build_V(Y, Z, (n_a, n_e))

    # ( a s s )

    Y = mdot(g_ae, [g_a, I_e])
    Z = np.zeros((n_a, n_a, n_e))
    V_as = build_V(Y, Z, (n_a, n_e))

    K_ass_1 = 2*mdot(f_2, [V_as, V_s]) + mdot(f_3, [V_a, V_s, V_s])
    K_ass_1 = np.tensordot(K_ass_1, Sigma)

    K_ass_2 = mdot(f_2, [V_a, V_ss])

    K_ass = K_ass_1 + K_ass_2

    L_ass = mdot(g_aa, [g_a, g_ss]) + np.tensordot(mdot(g_aee, [g_a, I_e, I_e]), Sigma)

    D = K_ass + sdot(f_d, L_ass)

    if derivatives_ss:
        f_1ss = derivatives_ss[1]
        D += mdot(f_1ss, [V_a]) + sdot(f1_A, mdot(g_aa, [g_a, g_ss]))

    g_ass = solve_sylvester(A, B, C, D)

    # ( e s s )

    Y = mdot(g_ae, [g_e, I_e])
    Z = np.zeros((n_a, n_e, n_e))
    V_es = build_V(Y, Z, (n_a, n_e))

    K_ess_1 = 2*mdot(f_2, [V_es, V_s]) + mdot(f_3, [V_e, V_s, V_s])
    K_ess_1 = np.tensordot(K_ess_1, Sigma)

    K_ess_2 = mdot(f_2, [V_e, V_ss])

    K_ess = K_ess_1 + K_ess_2

    L_ess = mdot(g_aa, [g_e, g_ss]) + np.tensordot(mdot(g_aee, [g_e, I_e, I_e]), Sigma)
    L_ess += mdot(g_ass, [g_e])

    D = K_ess + sdot(f_d, L_ess)

    g_ess = sdot(A_inv, -D)

    # max_order is necessarily 3 at this point.
    d = {
        'ev': ev,
        'g_a': g_a,
        'g_e': g_e,
        'g_aa': g_aa,
        'g_ae': g_ae,
        'g_ee': g_ee,
        'g_aaa': g_aaa,
        'g_aae': g_aae,
        'g_aee': g_aee,
        'g_eee': g_eee,
        'g_ss': g_ss,
        'g_ass': g_ass,
        'g_ess': g_ess
    }
    return d
def state_perturb(f_fun, g_fun, sigma, sigma2_correction=None):
    """
    Compute the perturbation of a system in the form:
    $E_t f(s_t,x_t,s_{t+1},x_{t+1})$
    $s_t = g(s_{t-1},x_{t-1},\\epsilon_t)$

    The order of approximation is ``len(f_fun) - 1``.

    :param f_fun: list of derivatives of f [order0, order1, order2, ...]
    :param g_fun: list of derivatives of g [order0, order1, order2, ...]
    :param sigma: matrix contracted against the shock axes of the
        second-order terms (presumably the covariance matrix of the shocks),
        or None to skip the sigma-dependent terms.
    :param sigma2_correction: optional sequence; item 0 is added to the
        equation for ``X_tt`` and item 1 (columns split into a block of
        ``n_s`` and a block of ``n_x``) enters the equation for ``X_stt``.
    :return: at order 1, ``[X_s]``; at order 2, ``[X_s, X_ss]`` if sigma is
        None, else ``[[X_s, X_ss], [X_tt]]``; at order 3,
        ``[X_s, X_ss, X_sss]`` if sigma is None, else
        ``[[X_s, X_ss, X_sss], [X_tt, X_stt]]``. Nothing is returned for
        higher orders.
    """
    import numpy as np
    from dolo.numeric.extern.qz import qzordered
    from numpy.linalg import solve

    approx_order = len(f_fun) - 1 # order of approximation

    [f0,f1] = f_fun[:2]

    [g0,g1] = g_fun[:2]
    n_x = f1.shape[0]           # number of controls
    # Integer division: the sizes are used as slice bounds and array
    # dimensions, so they must stay integers on Python 3 as well.
    n_s = f1.shape[1]//2 - n_x  # number of states
    n_e = g1.shape[1] - n_x - n_s
    n_v = n_s + n_x

    f_s = f1[:,:n_s]
    f_x = f1[:,n_s:n_s+n_x]
    f_snext = f1[:,n_v:n_v+n_s]
    f_xnext = f1[:,n_v+n_s:]

    g_s = g1[:,:n_s]
    g_x = g1[:,n_s:n_s+n_x]
    g_e = g1[:,n_v:]

    # Matrix pair passed to the ordered QZ decomposition.
    A = np.row_stack([
        np.column_stack( [ np.eye(n_s), np.zeros((n_s,n_x)) ] ),
        np.column_stack( [ -f_snext    , -f_xnext             ] )
    ])
    B = np.row_stack([
        np.column_stack( [ g_s, g_x ] ),
        np.column_stack( [ f_s, f_x ] )
    ])

    [S,T,Q,Z,eigval] = qzordered(A,B,n_s)

    Z11 = Z[:n_s,:n_s]
    Z21 = Z[n_s:,:n_s]

    # first order solution: X_s = Z21 . inv(Z11)
    X_s = solve(Z11.T, Z21.T).T

    if approx_order == 1:
        return [X_s]

    # second order solution
    from dolo.numeric.tensor import sdot, mdot
    from numpy import dot
    from dolo.numeric.matrix_equations import solve_sylvester

    f2 = f_fun[2]
    g2 = g_fun[2]
    g_ss = g2[:,:n_s,:n_s]
    g_sx = g2[:,:n_s,n_s:n_v]
    g_xx = g2[:,n_s:n_v,n_s:n_v]

    V1_3 = g_s + dot(g_x,X_s)
    V1 = np.row_stack([
        np.eye(n_s),
        X_s,
        V1_3,
        dot( X_s, V1_3 )
    ])

    K2 = g_ss + 2 * sdot(g_sx,X_s) + mdot(g_xx,[X_s,X_s])
    # A, B, C are reused unchanged for every Sylvester equation below.
    A = f_x + dot( f_snext + dot(f_xnext,X_s), g_x)
    B = f_xnext
    C = V1_3
    D = mdot(f2,[V1,V1]) + sdot(f_snext + dot(f_xnext,X_s),K2)

    X_ss = solve_sylvester(A,B,C,D)

    if sigma is not None:
        g_ee = g2[:,n_v:,n_v:]

        v = np.row_stack([
            g_e,
            dot(X_s,g_e)
        ])

        K_tt = mdot( f2[:,n_v:,n_v:], [v,v] )
        K_tt += sdot( f_snext + dot(f_xnext,X_s) , g_ee )
        K_tt += mdot( sdot( f_xnext, X_ss), [g_e, g_e] )
        K_tt = np.tensordot( K_tt, sigma, axes=((1,2),(0,1)))

        if sigma2_correction is not None:
            K_tt += sdot( f_snext + dot(f_xnext,X_s) , sigma2_correction[0] )

        L_tt = f_x  + dot(f_snext, g_x) + dot(f_xnext, dot(X_s, g_x) + np.eye(n_x) )
        X_tt = solve( L_tt, - K_tt)

    if approx_order == 2:
        if sigma is None:
            return [X_s,X_ss]  # here, we don't approximate the law of motion of the states
        else:
            return [[X_s,X_ss],[X_tt]]  # here, we don't approximate the law of motion of the states

    # third order solution

    f3 = f_fun[3]
    g3 = g_fun[3]
    g_sss = g3[:,:n_s,:n_s,:n_s]
    g_ssx = g3[:,:n_s,:n_s,n_s:n_v]
    g_sxx = g3[:,:n_s,n_s:n_v,n_s:n_v]
    g_xxx = g3[:,n_s:n_v,n_s:n_v,n_s:n_v]

    V2_3 = K2 + sdot(g_x,X_ss)
    V2 = np.row_stack([
        np.zeros( (n_s,n_s,n_s) ),
        X_ss,
        V2_3,
        dot( X_s, V2_3 ) + mdot(X_ss,[V1_3,V1_3])
    ])

    K3 = g_sss + 3*sdot(g_ssx,X_s) + 3*mdot(g_sxx,[X_s,X_s]) + 2*sdot(g_sx,X_ss)
    K3 += 3*mdot( g_xx,[X_ss,X_s] ) + mdot(g_xxx,[X_s,X_s,X_s])
    L3 = 3*mdot(X_ss,[V1_3,V2_3])

    # A, B, C: same as for the second order
    D = mdot(f3,[V1,V1,V1]) + 3*mdot(f2,[ V2,V1 ]) + sdot(f_snext + dot(f_xnext,X_s),K3)
    D += sdot( f_xnext, L3 )

    X_sss = solve_sylvester(A,B,C,D)

    # now doing sigma correction with sigma replaced by l in the subscripts

    if sigma is not None:
        g_se= g2[:,:n_s,n_v:]
        g_xe= g2[:,n_s:n_v,n_v:]

        g_see= g3[:,:n_s,n_v:,n_v:]
        g_xee= g3[:,n_s:n_v,n_v:,n_v:]

        W_l = np.row_stack([
            g_e,
            dot(X_s,g_e)
        ])

        I_e = np.eye(n_e)

        V_sl = g_se + mdot( g_xe, [X_s, np.eye(n_e)])

        W_sl = np.row_stack([
            V_sl,
            mdot( X_ss, [ V1_3, g_e ] ) + sdot( X_s, V_sl)
        ])

        K_ee = mdot(f3[:,:,n_v:,n_v:], [V1, W_l, W_l ])
        K_ee += 2 * mdot( f2[:,n_v:,n_v:], [W_sl, W_l])

        # stochastic part of W_ll

        SW_ll = np.row_stack([
            g_ee,
            mdot(X_ss, [g_e, g_e]) + sdot(X_s, g_ee)
        ])

        # part of W_ll driven by X_tt (no shock axes)
        DW_ll = np.concatenate([
            X_tt,
            dot(g_x, X_tt),
            dot(X_s, sdot(g_x,X_tt )) + X_tt
        ])

        K_ee += mdot( f2[:,:,n_v:], [V1, SW_ll])

        K_ = np.tensordot(K_ee, sigma, axes=((2,3),(0,1)))

        K_ += mdot(f2[:,:,n_s:], [V1, DW_ll])

        def E(vec):
            # Contract the last two axes of vec with sigma.
            n = len(vec.shape)
            return np.tensordot(vec,sigma,axes=((n-2,n-1),(0,1)))

        L = sdot(g_sx,X_tt) + mdot(g_xx,[X_s,X_tt])

        L += E(g_see + mdot(g_xee,[X_s,I_e,I_e]) )

        M = E( mdot(X_sss,[V1_3, g_e, g_e]) + 2*mdot(X_ss, [V_sl,g_e]) )
        M += mdot( X_ss, [V1_3, E( g_ee ) + sdot(g_x, X_tt)] )

        # A, B, C: same as before
        D = K_ + dot( f_snext + dot(f_xnext,X_s), L) + dot( f_xnext, M )

        if sigma2_correction is not None:
            g_sl = sigma2_correction[1][:,:n_s]
            g_xl = sigma2_correction[1][:,n_s:(n_s+n_x)]
            D += dot( f_snext + dot(f_xnext,X_s), g_sl + dot(g_xl,X_s) )   # constant

        X_stt = solve_sylvester(A,B,C,D)

    if approx_order == 3:
        if sigma is None:
            return [X_s,X_ss,X_sss]
        else:
            return [[X_s,X_ss,X_sss],[X_tt, X_stt]]
def state_perturb(f_fun, g_fun, sigma, sigma2_correction=None, verbose=True):
    """Computes a Taylor approximation of decision rules, given the supplied derivatives.

    The original system is assumed to be in the form:

    .. math::

        E_t f(s_t,x_t,s_{t+1},x_{t+1})

        s_t = g(s_{t-1},x_{t-1}, \\lambda \\epsilon_t)

    where :math:`\\lambda` is a scalar scaling down the risk.  The solution is a function :math:`\\varphi` such that:

    .. math::

        x_t = \\varphi ( s_t, \\sigma )

    The user supplies a list of derivatives of f and g.

    :param f_fun: list of derivatives of f [order0, order1, order2, ...]; the
        approximation order is ``len(f_fun) - 1``
    :param g_fun: list of derivatives of g [order0, order1, order2, ...]
    :param sigma: covariance matrix of :math:`\\epsilon_t`, or ``None`` to skip
        the risk corrections
    :param sigma2_correction: (optional) first and second derivatives of g w.r.t. sigma if :math:`g` explicitly depends
        on :math:`\\sigma`. Element 0 is used in the second-order risk term, element 1 in the third-order one.
    :param verbose: if true, print the number of eigenvalues greater than one
    :return: ``[X_s]`` at order 1. At order 2, ``[X_s, X_ss]`` when ``sigma`` is
        ``None`` and ``[[X_s, X_ss], [X_tt]]`` otherwise. At order 3,
        ``[X_s, X_ss, X_sss]`` when ``sigma`` is ``None`` and
        ``[[X_s, X_ss, X_sss], [X_tt, X_stt]]`` otherwise. For any higher order
        the function falls through and returns ``None``.
    :raises Exception: if the number of eigenvalues greater than one differs
        from the number of controls

    Assuming :math:`s_t` ,  :math:`x_t` and :math:`\\epsilon_t` are vectors of size
    :math:`n_s`, :math:`n_x`  and :math:`n_e`  respectively.
    In general the derivative of order :math:`i` of :math:`f`  is a multidimensional array of size :math:`n_x \\times (N, ..., N)`
    with :math:`N=2(n_s+n_x)` repeated :math:`i` times (possibly 0).
    Similarly the derivative of order :math:`i` of :math:`g`  is a multidimensional array of size :math:`n_s \\times (M, ..., M)`
    with :math:`M=n_s+n_x+n_e` repeated :math:`i` times (possibly 0).
    """

    import numpy as np
    from numpy.linalg import solve

    approx_order = len(f_fun) - 1  # order of approximation

    [f0, f1] = f_fun[:2]
    [g0, g1] = g_fun[:2]

    n_x = f1.shape[0]  # number of controls
    # Floor division keeps n_s an int: with true division it becomes a float
    # under Python 3 and every slice / np.eye(n_s) below raises TypeError.
    n_s = f1.shape[1] // 2 - n_x  # number of states
    n_e = g1.shape[1] - n_x - n_s  # number of shocks
    n_v = n_s + n_x

    # columns of f1 are ordered (s_t, x_t, s_{t+1}, x_{t+1})
    f_s = f1[:, :n_s]
    f_x = f1[:, n_s:n_s + n_x]
    f_snext = f1[:, n_v:n_v + n_s]
    f_xnext = f1[:, n_v + n_s:]

    # columns of g1 are ordered (s, x, e)
    g_s = g1[:, :n_s]
    g_x = g1[:, n_s:n_s + n_x]
    g_e = g1[:, n_v:]

    # matrix pencil (A, B) of the linearized system
    A = np.vstack([
        np.column_stack([np.eye(n_s), np.zeros((n_s, n_x))]),
        np.column_stack([-f_snext, -f_xnext])
    ])
    B = np.vstack([
        np.column_stack([g_s, g_x]),
        np.column_stack([f_s, f_x])
    ])

    from dolo.numeric.extern.qz import qzordered
    [S, T, Q, Z, eigval] = qzordered(A, B, n_s)

    # Check Blanchard-Kahn conditions
    n_big_one = sum(eigval > 1.0)
    n_expected = n_x
    if verbose:
        print("There are {} eigenvalues greater than 1. Expected: {}.".format(
            n_big_one, n_x))

    if n_big_one != n_expected:
        raise Exception(
            "There should be exactly {} eigenvalues greater than one. Not {}.".
            format(n_x, n_big_one))

    Z = Z.real  # is it really necessary ?

    Z11 = Z[:n_s, :n_s]
    Z21 = Z[n_s:, :n_s]
    S11 = S[:n_s, :n_s]
    T11 = T[:n_s, :n_s]

    # first order solution
    C = solve(Z11.T, Z21.T).T
    # P is computed but is not part of any returned value.
    P = np.dot(solve(S11.T, Z11.T).T, solve(Z11.T, T11.T).T)

    if approx_order == 1:
        return [C]

    # second order solution
    from dolo.numeric.tensor import sdot, mdot
    from numpy import dot
    from dolo.numeric.matrix_equations import solve_sylvester

    f2 = f_fun[2]
    g2 = g_fun[2]
    g_ss = g2[:, :n_s, :n_s]
    g_sx = g2[:, :n_s, n_s:n_v]
    g_xx = g2[:, n_s:n_v, n_s:n_v]

    X_s = C

    V1_3 = g_s + dot(g_x, X_s)
    V1 = np.vstack([np.eye(n_s), X_s, V1_3, dot(X_s, V1_3)])

    K2 = g_ss + 2 * sdot(g_sx, X_s) + mdot(g_xx, [X_s, X_s])
    # From here on A, B, C, D are the arguments handed to solve_sylvester;
    # C no longer refers to the first-order solution (kept as X_s).
    A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)
    B = f_xnext
    C = V1_3
    D = mdot(f2, [V1, V1]) + sdot(f_snext + dot(f_xnext, X_s), K2)

    X_ss = solve_sylvester(A, B, C, D)

    # Identity test: `sigma == None` on an ndarray compares element-wise and
    # the resulting array has no unambiguous truth value.
    if sigma is not None:
        g_ee = g2[:, n_v:, n_v:]

        v = np.vstack([g_e, dot(X_s, g_e)])

        K_tt = mdot(f2[:, n_v:, n_v:], [v, v])
        K_tt += sdot(f_snext + dot(f_xnext, X_s), g_ee)
        K_tt += mdot(sdot(f_xnext, X_ss), [g_e, g_e])
        # contract the two shock axes with the covariance matrix
        K_tt = np.tensordot(K_tt, sigma, axes=((1, 2), (0, 1)))

        if sigma2_correction is not None:
            K_tt += sdot(f_snext + dot(f_xnext, X_s), sigma2_correction[0])

        L_tt = f_x + dot(f_snext, g_x) + dot(f_xnext,
                                             dot(X_s, g_x) + np.eye(n_x))
        X_tt = solve(L_tt, -K_tt)

    if approx_order == 2:
        # here, we don't approximate the law of motion of the states
        if sigma is None:
            return [X_s, X_ss]
        else:
            return [[X_s, X_ss], [X_tt]]

    # third order solution

    f3 = f_fun[3]
    g3 = g_fun[3]
    g_sss = g3[:, :n_s, :n_s, :n_s]
    g_ssx = g3[:, :n_s, :n_s, n_s:n_v]
    g_sxx = g3[:, :n_s, n_s:n_v, n_s:n_v]
    g_xxx = g3[:, n_s:n_v, n_s:n_v, n_s:n_v]

    V2_3 = K2 + sdot(g_x, X_ss)
    V2 = np.vstack([
        np.zeros((n_s, n_s, n_s)), X_ss, V2_3,
        dot(X_s, V2_3) + mdot(X_ss, [V1_3, V1_3])
    ])

    K3 = g_sss + 3 * sdot(g_ssx, X_s) + 3 * mdot(g_sxx, [X_s, X_s]) + 2 * sdot(
        g_sx, X_ss)
    K3 += 3 * mdot(g_xx, [X_ss, X_s]) + mdot(g_xxx, [X_s, X_s, X_s])
    L3 = 3 * mdot(X_ss, [V1_3, V2_3])

    # A, B and C are the same as for the second-order equation.
    D = mdot(f3, [V1, V1, V1]) + 3 * mdot(f2, [V2, V1]) + sdot(
        f_snext + dot(f_xnext, X_s), K3)
    D += sdot(f_xnext, L3)

    X_sss = solve_sylvester(A, B, C, D)

    # now doing sigma correction with sigma replaced by l in the subscripts

    if sigma is not None:
        g_se = g2[:, :n_s, n_v:]
        g_xe = g2[:, n_s:n_v, n_v:]

        g_see = g3[:, :n_s, n_v:, n_v:]
        g_xee = g3[:, n_s:n_v, n_v:, n_v:]

        W_l = np.vstack([g_e, dot(X_s, g_e)])

        I_e = np.eye(n_e)

        V_sl = g_se + mdot(g_xe, [X_s, np.eye(n_e)])

        W_sl = np.vstack([V_sl, mdot(X_ss, [V1_3, g_e]) + sdot(X_s, V_sl)])

        K_ee = mdot(f3[:, :, n_v:, n_v:], [V1, W_l, W_l])
        K_ee += 2 * mdot(f2[:, n_v:, n_v:], [W_sl, W_l])

        # stochastic part of W_ll

        SW_ll = np.vstack([g_ee, mdot(X_ss, [g_e, g_e]) + sdot(X_s, g_ee)])

        DW_ll = np.concatenate(
            [X_tt, dot(g_x, X_tt),
             dot(X_s, sdot(g_x, X_tt)) + X_tt])

        K_ee += mdot(f2[:, :, n_v:], [V1, SW_ll])

        K_ = np.tensordot(K_ee, sigma, axes=((2, 3), (0, 1)))

        K_ += mdot(f2[:, :, n_s:], [V1, DW_ll])

        def E(vec):
            # contract the last two axes of vec with the covariance matrix
            n = len(vec.shape)
            return np.tensordot(vec, sigma, axes=((n - 2, n - 1), (0, 1)))

        L = sdot(g_sx, X_tt) + mdot(g_xx, [X_s, X_tt])

        L += E(g_see + mdot(g_xee, [X_s, I_e, I_e]))

        M = E(mdot(X_sss, [V1_3, g_e, g_e]) + 2 * mdot(X_ss, [V_sl, g_e]))
        M += mdot(X_ss, [V1_3, E(g_ee) + sdot(g_x, X_tt)])

        A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)  # same as before
        B = f_xnext  # same
        C = V1_3  # same
        D = K_ + dot(f_snext + dot(f_xnext, X_s), L) + dot(f_xnext, M)

        if sigma2_correction is not None:
            g_sl = sigma2_correction[1][:, :n_s]
            g_xl = sigma2_correction[1][:, n_s:(n_s + n_x)]
            D += dot(f_snext + dot(f_xnext, X_s),
                     g_sl + dot(g_xl, X_s))  # constant

        X_stt = solve_sylvester(A, B, C, D)

    if approx_order == 3:
        if sigma is None:
            return [X_s, X_ss, X_sss]
        else:
            return [[X_s, X_ss, X_sss], [X_tt, X_stt]]
def state_perturb(problem: PerturbationProblem, verbose=True):
    """Computes a Taylor approximation of decision rules, given the supplied derivatives.

    The original system is assumed to be in the form:

    .. math::

        E_t f(s_t,x_t,s_{t+1},x_{t+1})

        s_t = g(s_{t-1},x_{t-1}, \\lambda \\epsilon_t)

    where :math:`\\lambda` is a scalar scaling down the risk.  The solution is a function :math:`\\varphi` such that:

    .. math::

        x_t = \\varphi ( s_t, \\sigma )

    The derivatives of f and g are read from ``problem``.

    :param problem: object exposing ``order`` (order of approximation),
        ``f`` and ``g`` (lists of derivatives [order0, order1, order2, ...])
        and ``sigma`` (covariance matrix of :math:`\\epsilon_t`)
    :param verbose: currently unused by this function
    :return: ``[X_s]`` at order 1, ``[[X_s, X_ss], [X_tt]]`` at order 2 and
        ``[[X_s, X_ss, X_sss], [X_tt, X_stt]]`` at order 3. For any higher
        order the function falls through and returns ``None``.
    :raises GeneralizedEigenvaluesError: if some pair of diagonal entries of
        the QZ factors S and T are both below ``1e-10`` in absolute value
    :raises GeneralizedEigenvaluesSelectionError: if the eigenvalues cannot be
        split into two groups by a cutoff

    Assuming :math:`s_t` ,  :math:`x_t` and :math:`\\epsilon_t` are vectors of size
    :math:`n_s`, :math:`n_x`  and :math:`n_e`  respectively.
    In general the derivative of order :math:`i` of :math:`f`  is a multidimensional array of size :math:`n_x \\times (N, ..., N)`
    with :math:`N=2(n_s+n_x)` repeated :math:`i` times (possibly 0).
    Similarly the derivative of order :math:`i` of :math:`g`  is a multidimensional array of size :math:`n_s \\times (M, ..., M)`
    with :math:`M=n_s+n_x+n_e` repeated :math:`i` times (possibly 0).
    """

    import numpy as np
    from numpy.linalg import solve

    approx_order = problem.order  # order of approximation

    [f0, f1] = problem.f[:2]
    [g0, g1] = problem.g[:2]
    sigma = problem.sigma

    n_x = f1.shape[0]           # number of controls
    n_s = f1.shape[1]//2 - n_x   # number of states
    n_e = g1.shape[1] - n_x - n_s  # number of shocks
    n_v = n_s + n_x

    # columns of f1 are ordered (s_t, x_t, s_{t+1}, x_{t+1})
    f_s = f1[:, :n_s]
    f_x = f1[:, n_s:n_s+n_x]
    f_snext = f1[:, n_v:n_v+n_s]
    f_xnext = f1[:, n_v+n_s:]

    # columns of g1 are ordered (s, x, e)
    g_s = g1[:, :n_s]
    g_x = g1[:, n_s:n_s+n_x]
    g_e = g1[:, n_v:]

    # matrix pencil (A, B) of the linearized system
    A = np.vstack([
        np.column_stack([np.eye(n_s), np.zeros((n_s, n_x))]),
        np.column_stack([-f_snext    , -f_xnext           ])
    ])
    B = np.vstack([
        np.column_stack([g_s, g_x]),
        np.column_stack([f_s, f_x])
    ])

    [S, T, Q, Z, eigval] = qzordered(A, B, 1.0-1e-8)

    diag_S = np.diag(S)
    diag_T = np.diag(T)

    tol_geneigvals = 1e-10

    # Explicit test instead of `assert`: assertions are stripped under
    # `python -O`, which would silently disable this validation.
    degenerate = (abs(diag_S) < tol_geneigvals) * (abs(diag_T) < tol_geneigvals)
    if sum(degenerate) != 0:
        raise GeneralizedEigenvaluesError(diag_S=diag_S, diag_T=diag_T)

    # NOTE(review): this only requires the largest of the first n_s
    # eigenvalues to be >= 1 and the smallest of the others to be < 1;
    # confirm against the ordering convention of qzordered.
    if max(eigval[:n_s]) >= 1 and min(eigval[n_s:]) < 1:
        # BK conditions are met
        pass
    else:
        eigval_s = sorted(eigval, reverse=True)
        ev_a = eigval_s[n_s-1]
        ev_b = eigval_s[n_s]
        # NOTE(review): half the gap between the two eigenvalues, not their
        # midpoint -- confirm that this is the intended cutoff.
        cutoff = (ev_a - ev_b)/2
        if not ev_a > ev_b:
            raise GeneralizedEigenvaluesSelectionError(
                    A=A, B=B, eigval=eigval, cutoff=cutoff,
                    diag_S=diag_S, diag_T=diag_T, n_states=n_s
                )
        import warnings
        if cutoff > 1:
            warnings.warn("Solution is not convergent.")
        else:
            warnings.warn("There are multiple convergent solutions. The one with the smaller eigenvalues was selected.")
        [S, T, Q, Z, eigval] = qzordered(A, B, cutoff)

    # Real part taken after the final decomposition so that the re-ordered
    # result is treated exactly like the first one.
    Z = Z.real  # is it really necessary ?

    Z11 = Z[:n_s, :n_s]
    Z21 = Z[n_s:, :n_s]
    S11 = S[:n_s, :n_s]
    T11 = T[:n_s, :n_s]

    # first order solution
    C = solve(Z11.T, Z21.T).T
    # P is computed but is not part of any returned value.
    P = np.dot(solve(S11.T, Z11.T).T, solve(Z11.T, T11.T).T)

    if approx_order == 1:
        return [C]

    # second order solution
    from dolo.numeric.tensor import sdot, mdot
    from numpy import dot
    from dolo.numeric.matrix_equations import solve_sylvester

    f2 = problem.f[2]
    g2 = problem.g[2]
    g_ss = g2[:, :n_s, :n_s]
    g_sx = g2[:, :n_s, n_s:n_v]
    g_xx = g2[:, n_s:n_v, n_s:n_v]

    X_s = C

    V1_3 = g_s + dot(g_x, X_s)
    V1 = np.vstack([
        np.eye(n_s),
        X_s,
        V1_3,
        X_s @ V1_3
    ])

    K2 = g_ss + 2 * sdot(g_sx, X_s) + mdot(g_xx, X_s, X_s)
    # From here on A, B, C, D are the arguments handed to solve_sylvester;
    # C no longer refers to the first-order solution (kept as X_s).
    A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)
    B = f_xnext
    C = V1_3
    D = mdot(f2, V1, V1) + sdot(f_snext + dot(f_xnext, X_s), K2)

    X_ss = solve_sylvester(A, B, C, D)

    g_ee = g2[:, n_v:, n_v:]

    v = np.vstack([
        g_e,
        dot(X_s, g_e)
    ])

    K_tt = mdot(f2[:, n_v:, n_v:], v, v)
    K_tt += sdot(f_snext + dot(f_xnext, X_s), g_ee)
    K_tt += mdot(sdot(f_xnext, X_ss), g_e, g_e)
    # contract the two shock axes with the covariance matrix
    K_tt = np.tensordot(K_tt, sigma, axes=((1, 2), (0, 1)))

    L_tt = f_x + dot(f_snext, g_x) + dot(f_xnext, dot(X_s, g_x) + np.eye(n_x))
    X_tt = solve(L_tt, - K_tt)

    if approx_order == 2:
        return [[X_s, X_ss], [X_tt]]

    # third order solution

    f3 = problem.f[3]
    g3 = problem.g[3]
    g_sss = g3[:, :n_s, :n_s, :n_s]
    g_ssx = g3[:, :n_s, :n_s, n_s:n_v]
    g_sxx = g3[:, :n_s, n_s:n_v, n_s:n_v]
    g_xxx = g3[:, n_s:n_v, n_s:n_v, n_s:n_v]

    V2_3 = K2 + sdot(g_x, X_ss)
    V2 = np.vstack([
        np.zeros((n_s, n_s, n_s)),
        X_ss,
        V2_3,
        dot(X_s, V2_3) + mdot(X_ss, V1_3, V1_3)
    ])

    K3 = g_sss + 3*sdot(g_ssx, X_s) + 3*mdot(g_sxx, X_s, X_s) + 2*sdot(g_sx, X_ss)
    K3 += 3*mdot(g_xx, X_ss, X_s) + mdot(g_xxx, X_s, X_s, X_s)
    L3 = 3*mdot(X_ss, V1_3, V2_3)

    # A, B and C are the same as for the second-order equation.
    D = mdot(f3, V1, V1, V1) + 3*mdot(f2, V2, V1) + sdot(f_snext + dot(f_xnext, X_s), K3)
    D += sdot(f_xnext, L3)

    X_sss = solve_sylvester(A, B, C, D)

    # now doing sigma correction with sigma replaced by l in the subscripts

    g_se = g2[:, :n_s, n_v:]
    g_xe = g2[:, n_s:n_v, n_v:]

    g_see = g3[:, :n_s, n_v:, n_v:]
    g_xee = g3[:, n_s:n_v, n_v:, n_v:]

    W_l = np.vstack([
        g_e,
        dot(X_s, g_e)
    ])

    I_e = np.eye(n_e)

    V_sl = g_se + mdot(g_xe, X_s, np.eye(n_e))

    W_sl = np.vstack([
        V_sl,
        mdot(X_ss, V1_3, g_e) + sdot(X_s, V_sl)
    ])

    K_ee = mdot(f3[:, :, n_v:, n_v:], V1, W_l, W_l)
    K_ee += 2 * mdot(f2[:, n_v:, n_v:], W_sl, W_l)

    # stochastic part of W_ll

    SW_ll = np.vstack([
        g_ee,
        mdot(X_ss, g_e, g_e) + sdot(X_s, g_ee)
    ])

    DW_ll = np.concatenate([
        X_tt,
        dot(g_x, X_tt),
        dot(X_s, sdot(g_x, X_tt)) + X_tt
    ])

    K_ee += mdot(f2[:, :, n_v:], V1, SW_ll)

    K_ = np.tensordot(K_ee, sigma, axes=((2,3), (0,1)))

    K_ += mdot(f2[:, :, n_s:], V1, DW_ll)

    def E(vec):
        # contract the last two axes of vec with the covariance matrix
        n = len(vec.shape)
        return np.tensordot(vec, sigma, axes=((n-2, n-1), (0, 1)))

    L = sdot(g_sx, X_tt) + mdot(g_xx, X_s, X_tt)

    L += E(g_see + mdot(g_xee, X_s, I_e, I_e))

    M = E(mdot(X_sss, V1_3, g_e, g_e) + 2*mdot(X_ss, V_sl, g_e))
    M += mdot(X_ss, V1_3, E(g_ee) + sdot(g_x, X_tt))

    A = f_x + dot(f_snext + dot(f_xnext, X_s), g_x)  # same as before
    B = f_xnext  # same
    C = V1_3     # same
    D = K_ + dot(f_snext + dot(f_xnext, X_s), L) + dot(f_xnext, M)

    X_stt = solve_sylvester(A, B, C, D)

    if approx_order == 3:
        return [[X_s, X_ss, X_sss], [X_tt, X_stt]]