Example #1
def test_max_sigmaH_prior(self):
    D = 3
    Thrf = 25.
    dt = .5
    TT, m_h = getCanoHRF(Thrf, dt)
    m_h = m_h[:D]
    m_H = np.array(m_h).astype(np.float64)
    Sigma_H = np.ones((D, D), dtype=np.float64)
    order = 2
    D2 = vt.buildFiniteDiffMatrix(order, D)
    R = np.dot(D2, D2) / pow(dt, 2*order)
    gamma_h = 1000
    sigmaH = vt.maximization_sigmaH_prior(D, Sigma_H, R, m_H, gamma_h)
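
For reference, the smoothness prior matrix R used in this test comes from a second-order finite-difference operator scaled by the HRF sampling step. A minimal numpy-only sketch of that construction (a plain stand-in for `vt.buildFiniteDiffMatrix`, assuming it returns the repeated first-difference operator; not the pyhrf implementation itself):

import numpy as np

def finite_diff_matrix(order, size):
    # Stand-in for vt.buildFiniteDiffMatrix: the first-difference operator
    # applied `order` times (second differences for order=2).
    d = np.eye(size) - np.eye(size, k=-1)
    return np.linalg.matrix_power(d, order)

dt, order, D = 0.5, 2, 3
D2 = finite_diff_matrix(order, D)
R = np.dot(D2, D2) / dt ** (2 * order)   # same scaling as in the test above
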
Example #2
def Main_vbjde_Extension_constrained_stable(graph, Y, Onsets, Thrf, K, TR, beta,
                                            dt, scale=1, estimateSigmaH=True,
                                            sigmaH=0.05, NitMax=-1,
                                            NitMin=1, estimateBeta=True,
                                            PLOT=False, contrasts=[],
                                            computeContrast=False,
                                            gamma_h=0):
    """ Version modified by Lofti from Christine's version """
    logger.info(
        "Fast EM with C extension started ... Here is the stable version !")

    np.random.seed(6537546)

    # Initialize parameters
    S = 100
    if NitMax < 0:
        NitMax = 100
    gamma = 7.5  # 7.5
    gradientStep = 0.003
    MaxItGrad = 200
    Thresh = 1e-5
    eps = 1e-8  # small constant guarding the relative-change denominators below

    # Initialize sizes vectors
    D = np.int(np.ceil(Thrf / dt)) + 1
    M = len(Onsets)
    N = Y.shape[0]
    J = Y.shape[1]
    l = np.int(np.sqrt(J))
    condition_names = []

    # Neighbours
    maxNeighbours = max([len(nl) for nl in graph])
    neighboursIndexes = np.zeros((J, maxNeighbours), dtype=np.int32)
    neighboursIndexes -= 1
    for i in xrange(J):
        neighboursIndexes[i, :len(graph[i])] = graph[i]
    # Conditions
    X = OrderedDict([])
    for condition, Ons in Onsets.iteritems():
        X[condition] = vt.compute_mat_X_2(N, TR, D, dt, Ons)
        condition_names += [condition]
    XX = np.zeros((M, N, D), dtype=np.int32)
    nc = 0
    for condition, Ons in Onsets.iteritems():
        XX[nc, :, :] = X[condition]
        nc += 1
    # Covariance matrix
    order = 2
    D2 = vt.buildFiniteDiffMatrix(order, D)
    R = np.dot(D2, D2) / pow(dt, 2 * order)
    invR = np.linalg.inv(R)
    Det_invR = np.linalg.det(invR)

    Gamma = np.identity(N)
    Det_Gamma = np.linalg.det(Gamma)

    Crit_H = 1
    Crit_Z = 1
    Crit_A = 1
    Crit_AH = 1
    AH = np.zeros((J, M, D), dtype=np.float64)
    AH1 = np.zeros((J, M, D), dtype=np.float64)
    Crit_FreeEnergy = 1
    cTime = []
    cA = []
    cH = []
    cZ = []
    cAH = []
    h_norm = []  # HRF norm across iterations (appended after the constrained update)

    CONTRAST = np.zeros((J, len(contrasts)), dtype=np.float64)
    CONTRASTVAR = np.zeros((J, len(contrasts)), dtype=np.float64)
    Q_barnCond = np.zeros((M, M, D, D), dtype=np.float64)
    XGamma = np.zeros((M, D, N), dtype=np.float64)
    m1 = 0
    for k1 in X:  # Loop over the M conditions
        m2 = 0
        for k2 in X:
            Q_barnCond[m1, m2, :, :] = np.dot(
                np.dot(X[k1].transpose(), Gamma), X[k2])
            m2 += 1
        XGamma[m1, :, :] = np.dot(X[k1].transpose(), Gamma)
        m1 += 1

    sigma_epsilone = np.ones(J)
    logger.info(
        "Labels are initialized by setting active probabilities to ones ...")
    q_Z = np.zeros((M, K, J), dtype=np.float64)
    q_Z[:, 1, :] = 1
    q_Z1 = np.zeros((M, K, J), dtype=np.float64)
    Z_tilde = q_Z.copy()

    TT, m_h = getCanoHRF(Thrf, dt)  # TODO: check
    m_h = m_h[:D]
    m_H = np.array(m_h).astype(np.float64)
    m_H1 = np.array(m_h)
    sigmaH1 = sigmaH
    Sigma_H = np.ones((D, D), dtype=np.float64)

    Beta = beta * np.ones((M), dtype=np.float64)
    P = vt.PolyMat(N, 4, TR)
    L = vt.polyFit(Y, TR, 4, P)
    PL = np.dot(P, L)
    y_tilde = Y - PL
    Ndrift = L.shape[0]

    sigma_M = np.ones((M, K), dtype=np.float64)
    sigma_M[:, 0] = 0.5
    sigma_M[:, 1] = 0.6
    mu_M = np.zeros((M, K), dtype=np.float64)
    for k in xrange(1, K):
        mu_M[:, k] = 1  # InitMean
    Sigma_A = np.zeros((M, M, J), np.float64)
    for j in xrange(0, J):
        Sigma_A[:, :, j] = 0.01 * np.identity(M)
    m_A = np.zeros((J, M), dtype=np.float64)
    m_A1 = np.zeros((J, M), dtype=np.float64)
    for j in xrange(0, J):
        for m in xrange(0, M):
            for k in xrange(0, K):
                m_A[j, m] += np.random.normal(mu_M[m, k],
                                              np.sqrt(sigma_M[m, k])) * q_Z[m, k, j]
    m_A1 = m_A.copy()  # independent copy so the convergence criterion does not compare m_A with itself

    t1 = time.time()

    ##########################################################################
    # VBJDE num. iter. minimum

    ni = 0

    while ((ni < NitMin + 1) or ((Crit_AH > Thresh) and (ni < NitMax))):

        logger.info("------------------------------ Iteration n° " +
                    str(ni + 1) + " ------------------------------")

        #####################
        # EXPECTATION
        #####################

        # A
        logger.info("E A step ...")
        UtilsC.expectation_A(q_Z, mu_M, sigma_M, PL, sigma_epsilone, Gamma,
                             Sigma_H, Y, y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, K)

        # crit. A
        DIFF = np.reshape(m_A - m_A1, (M * J))
        Crit_A = (
            np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(m_A1, (M * J)))) ** 2
        cA += [Crit_A]
        m_A1[:, :] = m_A[:, :]

        # HRF h
        UtilsC.expectation_H(XGamma, Q_barnCond, sigma_epsilone, Gamma, R, Sigma_H, Y,
                             y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, scale, sigmaH)
        #m_H[0] = 0
        #m_H[-1] = 0
        # Constrain with optimization strategy
        import cvxpy as cvx
        m, n = Sigma_H.shape
        Sigma_H_inv = np.linalg.inv(Sigma_H)
        zeros_H = np.zeros_like(m_H[:, np.newaxis])
        # Construct the problem. PRIMAL
        h = cvx.Variable(n)
        expression = cvx.quad_form(h - m_H[:, np.newaxis], Sigma_H_inv)
        objective = cvx.Minimize(expression)
        #constraints = [h[0] == 0, h[-1]==0, h >= zeros_H, cvx.square(cvx.norm(h,2))<=1]
        constraints = [h[0] == 0, h[-1] == 0, cvx.square(cvx.norm(h, 2)) <= 1]
        prob = cvx.Problem(objective, constraints)
        result = prob.solve(verbose=0, solver=cvx.CVXOPT)
        # Now we update the mean of h
        m_H_old = m_H
        Sigma_H_old = Sigma_H
        m_H = np.squeeze(np.array((h.value)))
        Sigma_H = np.zeros_like(Sigma_H)
        # and the norm
        h_norm += [np.linalg.norm(m_H)]

        # crit. h
        Crit_H = (np.linalg.norm(m_H - m_H1) / np.linalg.norm(m_H1)) ** 2
        cH += [Crit_H]
        m_H1[:] = m_H[:]

        # crit. AH
        for d in xrange(0, D):
            AH[:, :, d] = m_A[:, :] * m_H[d]
        DIFF = np.reshape(AH - AH1, (M * J * D))
        Crit_AH = (np.linalg.norm(
            DIFF) / (np.linalg.norm(np.reshape(AH1, (M * J * D))) + eps)) ** 2
        cAH += [Crit_AH]
        AH1[:, :, :] = AH[:, :, :]

        # Z labels
        logger.info("E Z step ...")
        UtilsC.expectation_Z(Sigma_A, m_A, sigma_M, Beta, Z_tilde, mu_M,
                             q_Z, neighboursIndexes.astype(np.int32), M, J, K, maxNeighbours)

        # crit. Z
        DIFF = np.reshape(q_Z - q_Z1, (M * K * J))
        Crit_Z = (np.linalg.norm(DIFF) /
                  (np.linalg.norm(np.reshape(q_Z1, (M * K * J))) + eps)) ** 2
        cZ += [Crit_Z]
        q_Z1[:, :, :] = q_Z[:, :, :]

        #####################
        # MAXIMIZATION
        #####################

        # HRF: Sigma_h
        if estimateSigmaH:
            logger.info("M sigma_H step ...")
            if gamma_h > 0:
                sigmaH = vt.maximization_sigmaH_prior(
                    D, Sigma_H, R, m_H, gamma_h)
            else:
                sigmaH = vt.maximization_sigmaH(D, Sigma_H, R, m_H)
            logger.info('sigmaH = %s', str(sigmaH))

        # (mu,sigma)
        logger.info("M (mu,sigma) step ...")
        mu_M, sigma_M = vt.maximization_mu_sigma(
            mu_M, sigma_M, q_Z, m_A, K, M, Sigma_A)

        # Drift L
        UtilsC.maximization_L(
            Y, m_A, m_H, L, P, XX.astype(np.int32), J, D, M, Ndrift, N)
        PL = np.dot(P, L)
        y_tilde = Y - PL

        # Beta
        if estimateBeta:
            logger.info("estimating beta")
            for m in xrange(0, M):
                Beta[m] = UtilsC.maximization_beta(beta, q_Z[m, :, :].astype(np.float64), Z_tilde[m, :, :].astype(
                    np.float64), J, K, neighboursIndexes.astype(np.int32), gamma, maxNeighbours, MaxItGrad, gradientStep)
            logger.info("End estimating beta")
            logger.info(Beta)

        # Sigma noise
        logger.info("M sigma noise step ...")
        UtilsC.maximization_sigma_noise(
            Gamma, PL, sigma_epsilone, Sigma_H, Y, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N)

        # Update iteration counter (without this the loop never terminates)
        ni += 1

        t02 = time.time()
        cTime += [t02 - t1]

    t2 = time.time()

    ##########################################################################
    # PLOTS and SNR computation

    if PLOT and 0:
        font = {'size': 15}
        matplotlib.rc('font', **font)
        savefig('./HRF_Iter_CompMod.png')
        hold(False)
        figure(2)
        plot(cAH[1:-1], 'lightblue')
        hold(True)
        plot(cFE[1:-1], 'm')
        hold(False)
        legend(('CAH', 'CFE'))
        grid(True)
        savefig('./Crit_CompMod.png')
        figure(3)
        plot(FreeEnergyArray)
        grid(True)
        savefig('./FreeEnergy_CompMod.png')

        figure(4)
        for m in xrange(M):
            plot(SUM_q_Z_array[m])
            hold(True)
        hold(False)
        savefig('./Sum_q_Z_Iter_CompMod.png')

        figure(5)
        for m in xrange(M):
            plot(mu1_array[m])
            hold(True)
        hold(False)
        savefig('./mu1_Iter_CompMod.png')

        figure(6)
        plot(h_norm_array)
        savefig('./HRF_Norm_CompMod.png')

        Data_save = xndarray(h_norm_array, ['Iteration'])
        Data_save.save('./HRF_Norm_Comp.nii')

    CompTime = t2 - t1
    cTimeMean = CompTime / ni

    """
    Norm = np.linalg.norm(m_H)
    m_H /= Norm
    Sigma_H /= Norm**2
    sigmaH /= Norm**2
    m_A *= Norm
    Sigma_A *= Norm**2
    mu_M *= Norm
    sigma_M *= Norm**2
    sigma_M = np.sqrt(np.sqrt(sigma_M))
    """
    logger.info("Nb iterations to reach criterion: %d", ni)
    logger.info("Computational time = %s min %s s", str(
        np.int(CompTime // 60)), str(np.int(CompTime % 60)))
    logger.info('mu_M: %s', mu_M)
    logger.info('sigma_M: %s', sigma_M)
    logger.info("sigma_H = %s", str(sigmaH))
    logger.info("Beta = %s", str(Beta))

    StimulusInducedSignal = vt.computeFit(m_H, m_A, X, J, N)
    SNR = 20 * \
        np.log(
            np.linalg.norm(Y) / np.linalg.norm(Y - StimulusInducedSignal - PL))
    SNR /= np.log(10.)
    print 'SNR comp =', SNR
    return ni, m_A, m_H, q_Z, sigma_epsilone, mu_M, sigma_M, Beta, L, PL, CONTRAST, CONTRASTVAR, cA[2:], cH[2:], cZ[2:], cAH[2:], cTime[2:], cTimeMean, Sigma_A, StimulusInducedSignal
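
The distinctive step in this version is the constrained HRF update: the unconstrained posterior mean m_H is projected onto the set {h : h[0] = h[-1] = 0, ||h||_2 <= 1} under the Mahalanobis metric defined by Sigma_H. A self-contained sketch of that quadratic program with synthetic, well-conditioned inputs (illustrative values only; the function above passes the actual posterior mean and covariance and uses the CVXOPT solver):

import numpy as np
import cvxpy as cvx

D = 8
rng = np.random.RandomState(0)
m_H = rng.randn(D)                  # synthetic unconstrained posterior mean
Sigma_H = 0.1 * np.eye(D)           # synthetic, well-conditioned covariance
Sigma_H_inv = np.linalg.inv(Sigma_H)

h = cvx.Variable(D)
objective = cvx.Minimize(cvx.quad_form(h - m_H, Sigma_H_inv))
constraints = [h[0] == 0, h[-1] == 0, cvx.sum_squares(h) <= 1]
prob = cvx.Problem(objective, constraints)
prob.solve()                        # any QP-capable solver works here

m_H_constrained = np.asarray(h.value).ravel()
print(np.linalg.norm(m_H_constrained))   # <= 1 by construction
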
Example #3
def Main_vbjde_Extension_constrained(graph, Y, Onsets, Thrf, K, TR, beta,
                                     dt, scale=1, estimateSigmaH=True,
                                     sigmaH=0.05, NitMax=-1,
                                     NitMin=1, estimateBeta=True,
                                     PLOT=False, contrasts=[],
                                     computeContrast=False,
                                     gamma_h=0, estimateHRF=True,
                                     TrueHrfFlag=False,
                                     HrfFilename='hrf.nii',
                                     estimateLabels=True,
                                     LabelsFilename='labels.nii',
                                     MFapprox=False, InitVar=0.5,
                                     InitMean=2.0, MiniVEMFlag=False,
                                     NbItMiniVem=5):
    # VBJDE Function for BOLD with contraints

    logger.info("Fast EM with C extension started ...")
    np.random.seed(6537546)

    ##########################################################################
    # INITIALIZATIONS
    # Initialize parameters
    tau1 = 0.0
    tau2 = 0.0
    S = 100
    Init_sigmaH = sigmaH
    Nb2Norm = 1
    NormFlag = False
    if NitMax < 0:
        NitMax = 100
    gamma = 7.5
    #gamma_h = 1000
    gradientStep = 0.003
    MaxItGrad = 200
    Thresh = 1e-5
    Thresh_FreeEnergy = 1e-5
    estimateLabels = True  # WARNING!! They should be estimated

    # Initialize sizes vectors
    D = int(np.ceil(Thrf / dt)) + 1  # D = int(np.ceil(Thrf/dt))
    M = len(Onsets)
    N = Y.shape[0]
    J = Y.shape[1]
    l = int(np.sqrt(J))
    condition_names = []

    # Neighbours
    maxNeighbours = max([len(nl) for nl in graph])
    neighboursIndexes = np.zeros((J, maxNeighbours), dtype=np.int32)
    neighboursIndexes -= 1
    for i in xrange(J):
        neighboursIndexes[i, :len(graph[i])] = graph[i]
    # Conditions
    X = OrderedDict([])
    for condition, Ons in Onsets.iteritems():
        X[condition] = vt.compute_mat_X_2(N, TR, D, dt, Ons)
        condition_names += [condition]
    XX = np.zeros((M, N, D), dtype=np.int32)
    nc = 0
    for condition, Ons in Onsets.iteritems():
        XX[nc, :, :] = X[condition]
        nc += 1
    # Covariance matrix
    order = 2
    D2 = vt.buildFiniteDiffMatrix(order, D)
    R = np.dot(D2, D2) / pow(dt, 2 * order)
    invR = np.linalg.inv(R)
    Det_invR = np.linalg.det(invR)

    Gamma = np.identity(N)
    Det_Gamma = np.linalg.det(Gamma)

    p_Wtilde = np.zeros((M, K), dtype=np.float64)
    p_Wtilde1 = np.zeros((M, K), dtype=np.float64)
    p_Wtilde[:, 1] = 1

    Crit_H = 1
    Crit_Z = 1
    Crit_A = 1
    Crit_AH = 1
    AH = np.zeros((J, M, D), dtype=np.float64)
    AH1 = np.zeros((J, M, D), dtype=np.float64)
    Crit_FreeEnergy = 1

    cA = []
    cH = []
    cZ = []
    cAH = []
    FreeEnergy_Iter = []
    cTime = []
    cFE = []

    SUM_q_Z = [[] for m in xrange(M)]
    mu1 = [[] for m in xrange(M)]
    h_norm = []
    h_norm2 = []

    CONTRAST = np.zeros((J, len(contrasts)), dtype=np.float64)
    CONTRASTVAR = np.zeros((J, len(contrasts)), dtype=np.float64)
    Q_barnCond = np.zeros((M, M, D, D), dtype=np.float64)
    XGamma = np.zeros((M, D, N), dtype=np.float64)
    m1 = 0
    for k1 in X:  # Loop over the M conditions
        m2 = 0
        for k2 in X:
            Q_barnCond[m1, m2, :, :] = np.dot(
                np.dot(X[k1].transpose(), Gamma), X[k2])
            m2 += 1
        XGamma[m1, :, :] = np.dot(X[k1].transpose(), Gamma)
        m1 += 1

    if MiniVEMFlag:
        logger.info("MiniVEM to choose the best initialisation...")
        """InitVar, InitMean, gamma_h = MiniVEM_CompMod(Thrf,TR,dt,beta,Y,K,
                                                     gamma,gradientStep,
                                                     MaxItGrad,D,M,N,J,S,
                                                     maxNeighbours,
                                                     neighboursIndexes,
                                                     XX,X,R,Det_invR,Gamma,
                                                     Det_Gamma,
                                                     scale,Q_barnCond,XGamma,
                                                     NbItMiniVem,
                                                     sigmaH,estimateHRF)"""

        InitVar, InitMean, gamma_h = vt.MiniVEM_CompMod(Thrf, TR, dt, beta, Y, K, gamma, gradientStep, MaxItGrad, D, M, N, J, S, maxNeighbours,
                                                        neighboursIndexes, XX, X, R, Det_invR, Gamma, Det_Gamma, p_Wtilde, scale, Q_barnCond, XGamma, tau1, tau2, NbItMiniVem, sigmaH, estimateHRF)

    sigmaH = Init_sigmaH
    sigma_epsilone = np.ones(J)
    logger.info(
        "Labels are initialized by setting active probabilities to ones ...")
    q_Z = np.zeros((M, K, J), dtype=np.float64)
    q_Z[:, 1, :] = 1
    q_Z1 = np.zeros((M, K, J), dtype=np.float64)
    Z_tilde = q_Z.copy()

    # TT,m_h = getCanoHRF(Thrf-dt,dt) #TODO: check
    TT, m_h = getCanoHRF(Thrf, dt)  # TODO: check
    m_h = m_h[:D]
    m_H = np.array(m_h).astype(np.float64)
    m_H1 = np.array(m_h)
    sigmaH1 = sigmaH
    if estimateHRF:
        Sigma_H = np.ones((D, D), dtype=np.float64)
    else:
        Sigma_H = np.zeros((D, D), dtype=np.float64)

    Beta = beta * np.ones((M), dtype=np.float64)
    P = vt.PolyMat(N, 4, TR)
    L = vt.polyFit(Y, TR, 4, P)
    PL = np.dot(P, L)
    y_tilde = Y - PL
    Ndrift = L.shape[0]

    sigma_M = np.ones((M, K), dtype=np.float64)
    sigma_M[:, 0] = 0.5
    sigma_M[:, 1] = 0.6
    mu_M = np.zeros((M, K), dtype=np.float64)
    for k in xrange(1, K):
        mu_M[:, k] = InitMean
    Sigma_A = np.zeros((M, M, J), np.float64)
    for j in xrange(0, J):
        Sigma_A[:, :, j] = 0.01 * np.identity(M)
    m_A = np.zeros((J, M), dtype=np.float64)
    m_A1 = np.zeros((J, M), dtype=np.float64)
    for j in xrange(0, J):
        for m in xrange(0, M):
            for k in xrange(0, K):
                m_A[j, m] += np.random.normal(mu_M[m, k],
                                              np.sqrt(sigma_M[m, k])) * q_Z[m, k, j]
    m_A1 = m_A.copy()  # independent copy so the convergence criterion does not compare m_A with itself

    t1 = time.time()

    ##########################################################################
    # VBJDE num. iter. minimum

    ni = 0

    while ((ni < NitMin) or (((Crit_FreeEnergy > Thresh_FreeEnergy) or (Crit_AH > Thresh)) and (ni < NitMax))):

        logger.info("------------------------------ Iteration n° " +
                    str(ni + 1) + " ------------------------------")

        #####################
        # EXPECTATION
        #####################

        # A
        logger.info("E A step ...")
        UtilsC.expectation_A(q_Z, mu_M, sigma_M, PL, sigma_epsilone, Gamma,
                             Sigma_H, Y, y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, K)
        val = np.reshape(m_A, (M * J))
        val[np.where((val <= 1e-50) & (val > 0.0))] = 0.0
        val[np.where((val >= -1e-50) & (val < 0.0))] = 0.0

        # crit. A
        DIFF = np.reshape(m_A - m_A1, (M * J))
        # To avoid numerical problems
        DIFF[np.where((DIFF < 1e-50) & (DIFF > 0.0))] = 0.0
        # To avoid numerical problems
        DIFF[np.where((DIFF > -1e-50) & (DIFF < 0.0))] = 0.0
        Crit_A = (
            np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(m_A1, (M * J)))) ** 2
        cA += [Crit_A]
        m_A1[:, :] = m_A[:, :]

        # HRF h
        if estimateHRF:
            ################################
            #  HRF ESTIMATION
            ################################
            UtilsC.expectation_H(XGamma, Q_barnCond, sigma_epsilone, Gamma, R, Sigma_H, Y,
                                 y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, scale, sigmaH)

            import cvxpy as cvx
            m, n = Sigma_H.shape
            Sigma_H_inv = np.linalg.inv(Sigma_H)
            zeros_H = np.zeros_like(m_H[:, np.newaxis])

            # Construct the problem. PRIMAL
            h = cvx.Variable(n)
            expression = cvx.quad_form(h - m_H[:, np.newaxis], Sigma_H_inv)
            objective = cvx.Minimize(expression)
            #constraints = [h[0] == 0, h[-1]==0, h >= zeros_H, cvx.square(cvx.norm(h,2))<=1]
            constraints = [
                h[0] == 0, h[-1] == 0, cvx.square(cvx.norm(h, 2)) <= 1]
            prob = cvx.Problem(objective, constraints)
            result = prob.solve(verbose=0, solver=cvx.CVXOPT)

            # Now we update the mean of h
            m_H_old = m_H
            Sigma_H_old = Sigma_H
            m_H = np.squeeze(np.array((h.value)))
            Sigma_H = np.zeros_like(Sigma_H)

            h_norm += [np.linalg.norm(m_H)]
            # print 'h_norm = ', h_norm

            # Plotting HRF
            if PLOT and ni >= 0:
                import matplotlib.pyplot as plt
                plt.figure(M + 1)
                plt.plot(m_H)
                plt.hold(True)
        else:
            if TrueHrfFlag:
                #TrueVal, head = read_volume(HrfFilename)
                TrueVal, head = read_volume(HrfFilename)
                TrueVal = TrueVal[:, 0, 0, 0]
                print TrueVal
                print TrueVal.shape
                m_H = TrueVal

        # crit. h
        Crit_H = (np.linalg.norm(m_H - m_H1) / np.linalg.norm(m_H1)) ** 2
        cH += [Crit_H]
        m_H1[:] = m_H[:]

        # crit. AH
        for d in xrange(0, D):
            AH[:, :, d] = m_A[:, :] * m_H[d]
        DIFF = np.reshape(AH - AH1, (M * J * D))
        # To avoid numerical problems
        DIFF[np.where((DIFF < 1e-50) & (DIFF > 0.0))] = 0.0
        # To avoid numerical problems
        DIFF[np.where((DIFF > -1e-50) & (DIFF < 0.0))] = 0.0
        if np.linalg.norm(np.reshape(AH1, (M * J * D))) == 0:
            Crit_AH = 1000000000.
        else:
            Crit_AH = (
                np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(AH1, (M * J * D)))) ** 2
        cAH += [Crit_AH]
        AH1[:, :, :] = AH[:, :, :]

        # Z labels
        if estimateLabels:
            logger.info("E Z step ...")
            # WARNING!!! ParsiMod gives better results, but we need the other
            # one.
            if MFapprox:
                UtilsC.expectation_Z(Sigma_A, m_A, sigma_M, Beta, Z_tilde, mu_M, q_Z, neighboursIndexes.astype(
                    np.int32), M, J, K, maxNeighbours)
            if not MFapprox:
                UtilsC.expectation_Z_ParsiMod_RVM_and_CompMod(
                    Sigma_A, m_A, sigma_M, Beta, mu_M, q_Z, neighboursIndexes.astype(np.int32), M, J, K, maxNeighbours)
        else:
            logger.info("Using True Z ...")
            TrueZ = read_volume(LabelsFilename)
            for m in xrange(M):
                q_Z[m, 1, :] = np.reshape(TrueZ[0][:, :, :, m], J)
                q_Z[m, 0, :] = 1 - q_Z[m, 1, :]

        # crit. Z
        val = np.reshape(q_Z, (M * K * J))
        val[np.where((val <= 1e-50) & (val > 0.0))] = 0.0

        DIFF = np.reshape(q_Z - q_Z1, (M * K * J))
        # To avoid numerical problems
        DIFF[np.where((DIFF < 1e-50) & (DIFF > 0.0))] = 0.0
        # To avoid numerical problems
        DIFF[np.where((DIFF > -1e-50) & (DIFF < 0.0))] = 0.0
        if np.linalg.norm(np.reshape(q_Z1, (M * K * J))) == 0:
            Crit_Z = 1000000000.
        else:
            Crit_Z = (
                np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(q_Z1, (M * K * J)))) ** 2
        cZ += [Crit_Z]
        q_Z1 = q_Z

        #####################
        # MAXIMIZATION
        #####################

        # HRF: Sigma_h
        if estimateHRF:
            if estimateSigmaH:
                logger.info("M sigma_H step ...")
                if gamma_h > 0:
                    sigmaH = vt.maximization_sigmaH_prior(
                        D, Sigma_H_old, R, m_H_old, gamma_h)
                else:
                    sigmaH = vt.maximization_sigmaH(D, Sigma_H, R, m_H)
                logger.info('sigmaH = %s', str(sigmaH))

        # (mu,sigma)
        logger.info("M (mu,sigma) step ...")
        mu_M, sigma_M = vt.maximization_mu_sigma(
            mu_M, sigma_M, q_Z, m_A, K, M, Sigma_A)
        for m in xrange(M):
            SUM_q_Z[m] += [sum(q_Z[m, 1, :])]
            mu1[m] += [mu_M[m, 1]]

        # Drift L
        UtilsC.maximization_L(
            Y, m_A, m_H, L, P, XX.astype(np.int32), J, D, M, Ndrift, N)
        PL = np.dot(P, L)
        y_tilde = Y - PL

        # Beta
        if estimateBeta:
            logger.info("estimating beta")
            for m in xrange(0, M):
                if MFapprox:
                    Beta[m] = UtilsC.maximization_beta(beta, q_Z[m, :, :].astype(np.float64), Z_tilde[m, :, :].astype(
                        np.float64), J, K, neighboursIndexes.astype(np.int32), gamma, maxNeighbours, MaxItGrad, gradientStep)
                if not MFapprox:
                    #Beta[m] = UtilsC.maximization_beta(beta,q_Z[m,:,:].astype(np.float64),q_Z[m,:,:].astype(np.float64),J,K,neighboursIndexes.astype(int32),gamma,maxNeighbours,MaxItGrad,gradientStep)
                    Beta[m] = UtilsC.maximization_beta_CB(beta, q_Z[m, :, :].astype(
                        np.float64), J, K, neighboursIndexes.astype(np.int32), gamma, maxNeighbours, MaxItGrad, gradientStep)
            logger.info("End estimating beta")
            logger.info(Beta)

        # Sigma noise
        logger.info("M sigma noise step ...")
        UtilsC.maximization_sigma_noise(
            Gamma, PL, sigma_epsilone, Sigma_H, Y, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N)

        #### Computing Free Energy ####
        if ni > 0:
            FreeEnergy1 = FreeEnergy

        """FreeEnergy = vt.Compute_FreeEnergy(y_tilde,m_A,Sigma_A,mu_M,sigma_M,
                                           m_H,Sigma_H,R,Det_invR,sigmaH,
                                           p_Wtilde,q_Z,neighboursIndexes,
                                           maxNeighbours,Beta,sigma_epsilone,
                                           XX,Gamma,Det_Gamma,XGamma,J,D,M,
                                           N,K,S,"CompMod")"""
        FreeEnergy = vt.Compute_FreeEnergy(y_tilde, m_A, Sigma_A, mu_M, sigma_M, m_H, Sigma_H, R, Det_invR, sigmaH, p_Wtilde, tau1,
                                           tau2, q_Z, neighboursIndexes, maxNeighbours, Beta, sigma_epsilone, XX, Gamma, Det_Gamma, XGamma, J, D, M, N, K, S, "CompMod")

        if ni > 0:
            Crit_FreeEnergy = (FreeEnergy1 - FreeEnergy) / FreeEnergy1
        FreeEnergy_Iter += [FreeEnergy]
        cFE += [Crit_FreeEnergy]

        # Update index
        ni += 1

        t02 = time.time()
        cTime += [t02 - t1]

    t2 = time.time()

    ##########################################################################
    # PLOTS and SNR computation

    FreeEnergyArray = np.zeros((ni), dtype=np.float64)
    for i in xrange(ni):
        FreeEnergyArray[i] = FreeEnergy_Iter[i]

    SUM_q_Z_array = np.zeros((M, ni), dtype=np.float64)
    mu1_array = np.zeros((M, ni), dtype=np.float64)
    h_norm_array = np.zeros((ni), dtype=np.float64)
    for m in xrange(M):
        for i in xrange(ni):
            SUM_q_Z_array[m, i] = SUM_q_Z[m][i]
            mu1_array[m, i] = mu1[m][i]
            h_norm_array[i] = h_norm[i]

    if PLOT and 0:
        import matplotlib.pyplot as plt
        import matplotlib
        font = {'size': 15}
        matplotlib.rc('font', **font)
        plt.savefig('./HRF_Iter_CompMod.png')
        plt.hold(False)
        plt.figure(2)
        plt.plot(cAH[1:-1], 'lightblue')
        plt.hold(True)
        plt.plot(cFE[1:-1], 'm')
        plt.hold(False)
        #plt.legend( ('CA','CH', 'CZ', 'CAH', 'CFE') )
        plt.legend(('CAH', 'CFE'))
        plt.grid(True)
        plt.savefig('./Crit_CompMod.png')
        plt.figure(3)
        plt.plot(FreeEnergyArray)
        plt.grid(True)
        plt.savefig('./FreeEnergy_CompMod.png')

        plt.figure(4)
        for m in xrange(M):
            plt.plot(SUM_q_Z_array[m])
            plt.hold(True)
        plt.hold(False)
        #plt.legend( ('m=0','m=1', 'm=2', 'm=3') )
        #plt.legend( ('m=0','m=1') )
        plt.savefig('./Sum_q_Z_Iter_CompMod.png')

        plt.figure(5)
        for m in xrange(M):
            plt.plot(mu1_array[m])
            plt.hold(True)
        plt.hold(False)
        plt.savefig('./mu1_Iter_CompMod.png')

        plt.figure(6)
        plt.plot(h_norm_array)
        plt.savefig('./HRF_Norm_CompMod.png')

        Data_save = xndarray(h_norm_array, ['Iteration'])
        Data_save.save('./HRF_Norm_Comp.nii')

    CompTime = t2 - t1
    cTimeMean = CompTime / ni

    sigma_M = np.sqrt(np.sqrt(sigma_M))
    logger.info("Nb iterations to reach criterion: %d", ni)
    logger.info("Computational time = %s min %s s", str(
        int(CompTime // 60)), str(int(CompTime % 60)))
    # print "Computational time = " + str(int( CompTime//60 ) ) + " min " + str(int(CompTime%60)) + " s"
    # print "sigma_H = " + str(sigmaH)
    logger.info('mu_M: %s', mu_M)
    logger.info('sigma_M: %s', sigma_M)
    logger.info("sigma_H = %s", str(sigmaH))
    logger.info("Beta = %s", str(Beta))

    StimulusInducedSignal = vt.computeFit(m_H, m_A, X, J, N)
    SNR = 20 * \
        np.log(
            np.linalg.norm(Y) / np.linalg.norm(Y - StimulusInducedSignal - PL))
    SNR /= np.log(10.)
    logger.info("SNR = %d", SNR)
    return ni, m_A, m_H, q_Z, sigma_epsilone, mu_M, sigma_M, Beta, L, PL, CONTRAST, CONTRASTVAR, cA[2:], cH[2:], cZ[2:], cAH[2:], cTime[2:], cTimeMean, Sigma_A, StimulusInducedSignal, FreeEnergyArray
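
Both versions above stop on squared relative-change criteria of the form ||x_new - x_old||^2 / ||x_old||^2, computed for m_A, m_H, the product AH and the labels q_Z. The criterion in isolation (hypothetical helper, written only to make the stopping rule explicit):

import numpy as np

def relative_change_sq(new, old, eps=1e-8):
    # Squared relative L2 change, as used for Crit_A, Crit_H, Crit_AH and Crit_Z.
    diff = np.ravel(new - old)
    denom = np.linalg.norm(np.ravel(old)) + eps
    return (np.linalg.norm(diff) / denom) ** 2

old = np.ones((4, 3))
new = old + 0.01
print(relative_change_sq(new, old))   # ~1e-4, above the 1e-5 threshold used above
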
Example #4
def jde_vem_bold(graph, bold_data, onsets, durations, hrf_duration, nb_classes,
                 tr, beta, dt, estimate_sigma_h=True, sigma_h=0.05,
                 it_max=-1, it_min=0, estimate_beta=True, contrasts=None,
                 compute_contrasts=False, hrf_hyperprior=0, estimate_hrf=True,
                 constrained=False, zero_constraint=True, drifts_type="poly",
                 seed=6537546):
    """This is the main function that computes the VEM analysis on BOLD data.
    This function uses optimized python functions.

    Parameters
    ----------
    graph : ndarray of lists
        represents the neighbours indexes of each voxels index
    bold_data : ndarray, shape (nb_scans, nb_voxels)
        raw data
    onsets : dict
        dictionary of onsets
    durations : # TODO
        # TODO
    hrf_duration : float
        hrf total time duration (in s)
    nb_classes : int
        the number of classes to classify the nrls. This parameter is provided
        for development purposes as most of the algorithm assumes two classes
    tr : float
        time of repetition
    beta : float
        the initial value of beta
    dt : float
        hrf temporal precision
    estimate_sigma_h : bool, optional
        toggle estimation of sigma H
    sigma_h : float, optional
        initial or fixed value of sigma H
    it_max : int, optional
        maximal computed iteration number
    it_min : int, optional
        minimal computed iteration number
    estimate_beta : bool, optional
        toggle the estimation of beta
    contrasts : OrderedDict, optional
        dict of contrasts to compute
    compute_contrasts : bool, optional
        if True, compute the contrasts defined in contrasts
    hrf_hyperprior : float
        # TODO
    estimate_hrf : bool, optional
        if True, estimate the HRF for each parcel, if False use the canonical HRF
    constrained : bool, optional
        if True, constrain the l2 norm of the HRF to 1
    zero_constraint : bool, optional
        if True, add zeros to the beginning and the end of the estimated HRF.
    drifts_type : str, optional
        set the drifts basis type used. Can be "poly" for polynomial or "cos"
        for cosine
    seed : int, optional
        seed used by numpy to initialize the random number generator

    Returns
    -------
    loop : int
        number of iterations before convergence
    nrls_mean : ndarray, shape (nb_voxels, nb_conditions)
        Neural response level mean value
    hrf_mean : ndarray, shape (hrf_len,)
        Hemodynamic response function mean value
    hrf_covar : ndarray, shape (hrf_len, hrf_len)
        Covariance matrix of the HRF
    labels_proba : ndarray, shape (nb_conditions, nb_classes, nb_voxels)
        probability of voxels being in one class
    noise_var : ndarray, shape (nb_voxels,)
        estimated noise variance
    nrls_class_mean : ndarray, shape (nb_conditions, nb_classes)
        estimated mean value of the gaussians of the classes
    nrls_class_var : ndarray, shape (nb_conditions, nb_classes)
        estimated variance of the gaussians of the classes
    beta : ndarray, shape (nb_conditions,)
        estimated beta
    drift_coeffs : ndarray, shape (# TODO)
        estimated coefficient of the drifts
    drift : ndarray, shape (# TODO)
        estimated drifts
    contrasts_mean : ndarray, shape (nb_voxels, len(contrasts))
        Contrasts computed from NRLs
    contrasts_var : ndarray, shape (nb_voxels, len(contrasts))
        Variance of the contrasts
    compute_time : list
        computation time of each iteration
    compute_time_mean : float
        computation mean time over iterations
    nrls_covar : ndarray, shape (nb_conditions, nb_conditions, nb_voxels)
        # TODO
    stimulus_induced_signal : ndarray, shape (nb_scans, nb_voxels)
        # TODO
    mahalanobis_zero : float
        Mahalanobis distance between estimated hrf_mean and the null vector
    mahalanobis_cano : float
        Mahalanobis distance between estimated hrf_mean and the canonical HRF
    mahalanobis_diff : float
        difference between mahalanobis_cano and mahalanobis_zero
    mahalanobis_prod : float
        product of mahalanobis_cano and mahalanobis_zero
    ppm_a_nrl : ndarray, shape (nb_voxels,)
        The posterior probability map using an alpha
    ppm_g_nrl : ndarray, shape (nb_voxels,)
        # TODO
    ppm_a_contrasts : ndarray, shape (nb_voxels,)
        # TODO
    ppm_g_contrasts : ndarray, shape (nb_voxels,)
        # TODO
    variation_coeff : float
        coefficient of variation of the HRF
    free_energy : list
        # TODO

    Notes
    -----
        See `A novel definition of the multivariate coefficient of variation
        <http://onlinelibrary.wiley.com/doi/10.1002/bimj.201000030/abstract>`_
        article for more information about the coefficient of variation.
    """

    logger.info("VEM started.")

    if not contrasts:
        contrasts = OrderedDict()

    np.random.seed(seed)

    nb_2_norm = 1
    normalizing = False
    regularizing = False

    if it_max <= 0:
        it_max = 100
    gamma = 7.5
    thresh_free_energy = 1e-4

    # Initialize sizes vectors
    hrf_len = np.int(np.ceil(hrf_duration / dt)) + 1
    nb_conditions = len(onsets)
    nb_scans = bold_data.shape[0]
    nb_voxels = bold_data.shape[1]
    X, occurence_matrix, condition_names = vt.create_conditions(
        onsets, durations, nb_conditions, nb_scans, hrf_len, tr, dt
    )
    neighbours_indexes = vt.create_neighbours(graph)

    order = 2
    if regularizing:
        regularization = np.ones(hrf_len)
        regularization[hrf_len//3:hrf_len//2] = 2
        regularization[hrf_len//2:2*hrf_len//3] = 5
        regularization[2*hrf_len//3:3*hrf_len//4] = 7
        regularization[3*hrf_len//4:] = 10
        # regularization[hrf_len//2:] = 10
    else:
        regularization = None
    d2 = vt.buildFiniteDiffMatrix(order, hrf_len, regularization)
    hrf_regu_prior_inv = d2.T.dot(d2) / pow(dt, 2 * order)

    if estimate_hrf and zero_constraint:
        hrf_len = hrf_len - 2
        hrf_regu_prior_inv = hrf_regu_prior_inv[1:-1, 1:-1]
        occurence_matrix = occurence_matrix[:, :, 1:-1]
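        # Note: the two endpoints removed here are padded back as exact zeros
        # after convergence (see the concatenations at the end of the function).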

    noise_struct = np.identity(nb_scans)

    free_energy = [1.]
    free_energy_crit = [1.]
    compute_time = []

    noise_var = np.ones(nb_voxels)

    labels_proba = np.zeros((nb_conditions, nb_classes, nb_voxels), dtype=np.float64)
    logger.info("Labels are initialized by setting everything to {}".format(1./nb_classes))
    labels_proba[:, :, :] = 1./nb_classes

    m_h = getCanoHRF(hrf_duration, dt)[1][:hrf_len]
    hrf_mean = np.array(m_h).astype(np.float64)
    if estimate_hrf:
        hrf_covar = np.identity(hrf_len, dtype=np.float64)
    else:
        hrf_covar = np.zeros((hrf_len, hrf_len), dtype=np.float64)

    beta = beta * np.ones((nb_conditions), dtype=np.float64)
    beta_list = []
    beta_list.append(beta.copy())
    if drifts_type == "poly":
        drift_basis = vt.poly_drifts_basis(nb_scans, 4, tr)
    elif drifts_type == "cos":
        drift_basis = vt.cosine_drifts_basis(nb_scans, 64, tr)
    drift_coeffs = vt.drifts_coeffs_fit(bold_data, drift_basis)
    drift = drift_basis.dot(drift_coeffs)
    bold_data_drift = bold_data - drift

    # Parameters Gaussian mixtures
    nrls_class_mean = 2 * np.ones((nb_conditions, nb_classes))
    nrls_class_mean[:, 0] = 0
    nrls_class_var = 0.3 * np.ones((nb_conditions, nb_classes), dtype=np.float64)

    nrls_mean = (np.random.normal(
        nrls_class_mean, nrls_class_var)[:, :, np.newaxis] * labels_proba).sum(axis=1).T
    nrls_covar = (np.identity(nb_conditions)[:, :, np.newaxis] + np.zeros((1, 1, nb_voxels)))

    start_time = time.time()
    loop = 0
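    # Stopping rule: run at least it_min iterations, then continue while any of
    # the last five relative free-energy changes is still above the threshold
    # and the iteration cap it_max has not been reached.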
    while (loop <= it_min or
           ((np.asarray(free_energy_crit[-5:]) > thresh_free_energy).any()
            and loop < it_max)):

        logger.info("{:-^80}".format(" Iteration n°"+str(loop+1)+" "))

        logger.info("Expectation A step...")
        logger.debug("Before: nrls_mean = %s, nrls_covar = %s", nrls_mean, nrls_covar)
        nrls_mean, nrls_covar = vt.nrls_expectation(
            hrf_mean, nrls_mean, occurence_matrix, noise_struct, labels_proba,
            nrls_class_mean, nrls_class_var, nb_conditions, bold_data_drift, nrls_covar,
            hrf_covar, noise_var)
        logger.debug("After: nrls_mean = %s, nrls_covar = %s", nrls_mean, nrls_covar)

        logger.info("Expectation Z step...")
        logger.debug("Before: labels_proba = %s, labels_proba = %s", labels_proba, labels_proba)
        labels_proba = vt.labels_expectation(
            nrls_covar, nrls_mean, nrls_class_var, nrls_class_mean, beta,
            labels_proba, neighbours_indexes, nb_conditions, nb_classes,
            nb_voxels, parallel=True)
        logger.debug("After: labels_proba = %s, labels_proba = %s", labels_proba, labels_proba)

        if estimate_hrf:
            logger.info("Expectation H step...")
            logger.debug("Before: hrf_mean = %s, hrf_covar = %s", hrf_mean, hrf_covar)
            hrf_mean, hrf_covar = vt.hrf_expectation(
                nrls_covar, nrls_mean, occurence_matrix, noise_struct,
                hrf_regu_prior_inv, sigma_h, nb_voxels, bold_data_drift, noise_var)
            if constrained:
                hrf_mean = vt.norm1_constraint(hrf_mean, hrf_covar)
                hrf_covar[:] = 0
            logger.debug("After: hrf_mean = %s, hrf_covar = %s", hrf_mean, hrf_covar)
            # Normalizing H at each nb_2_norm iterations:
            if not constrained and normalizing:
                # Normalizing is done before sigma_h, nrls_class_mean and nrls_class_var estimation
                # we should not include them in the normalisation step
                if (loop + 1) % nb_2_norm == 0:
                    hrf_norm = np.linalg.norm(hrf_mean)
                    hrf_mean /= hrf_norm
                    hrf_covar /= hrf_norm ** 2
                    nrls_mean *= hrf_norm
                    nrls_covar *= hrf_norm ** 2

        if estimate_hrf and estimate_sigma_h:
            logger.info("Maximization sigma_H step...")
            logger.debug("Before: sigma_h = %s", sigma_h)
            if hrf_hyperprior > 0:
                sigma_h = vt.maximization_sigmaH_prior(hrf_len, hrf_covar,
                                                       hrf_regu_prior_inv,
                                                       hrf_mean, hrf_hyperprior)
            else:
                sigma_h = vt.maximization_sigmaH(hrf_len, hrf_covar,
                                                 hrf_regu_prior_inv, hrf_mean)
            logger.debug("After: sigma_h = %s", sigma_h)

        logger.info("Maximization (mu,sigma) step...")
        logger.debug("Before: nrls_class_mean = %s, nrls_class_var = %s",
                     nrls_class_mean, nrls_class_var)
        nrls_class_mean, nrls_class_var = vt.maximization_class_proba(
            labels_proba, nrls_mean, nrls_covar
        )
        logger.debug("After: nrls_class_mean = %s, nrls_class_var = %s",
                     nrls_class_mean, nrls_class_var)

        logger.info("Maximization L step...")
        logger.debug("Before: drift_coeffs = %s", drift_coeffs)
        drift_coeffs = vt.maximization_drift_coeffs(
            bold_data, nrls_mean, occurence_matrix, hrf_mean, noise_struct, drift_basis
        )
        logger.debug("After: drift_coeffs = %s", drift_coeffs)

        drift = drift_basis.dot(drift_coeffs)
        bold_data_drift = bold_data - drift
        if estimate_beta:
            logger.info("Maximization beta step...")
            for cond_nb in xrange(0, nb_conditions):
                beta[cond_nb], success = vt.beta_maximization(
                    beta[cond_nb]*np.ones((1,)), labels_proba[cond_nb, :, :],
                    neighbours_indexes, gamma
                )
            beta_list.append(beta.copy())
            logger.debug("beta = %s", str(beta))

        logger.info("Maximization sigma noise step...")
        noise_var = vt.maximization_noise_var(
            occurence_matrix, hrf_mean, hrf_covar, nrls_mean, nrls_covar,
            noise_struct, bold_data_drift, nb_scans
        )

        #### Computing Free Energy ####
        free_energy.append(vt.free_energy_computation(
            nrls_mean, nrls_covar, hrf_mean, hrf_covar, hrf_len, labels_proba,
            bold_data_drift, occurence_matrix, noise_var, noise_struct, nb_conditions,
            nb_voxels, nb_scans, nb_classes, nrls_class_mean, nrls_class_var, neighbours_indexes,
            beta, sigma_h, np.linalg.inv(hrf_regu_prior_inv), hrf_regu_prior_inv, gamma, hrf_hyperprior
        ))
        free_energy_crit.append(abs((free_energy[-2] - free_energy[-1]) /
                                    free_energy[-2]))

        logger.info("Convergence criteria: %f (Threshold = %f)",
                    free_energy_crit[-1], thresh_free_energy)
        loop += 1
        compute_time.append(time.time() - start_time)

    compute_time_mean = compute_time[-1] / loop

    mahalanobis_zero = np.nan
    mahalanobis_cano = np.nan
    mahalanobis_diff = np.nan
    mahalanobis_prod = np.nan
    variation_coeff = np.nan

    if estimate_hrf and not constrained and not normalizing:
        hrf_norm = np.linalg.norm(hrf_mean)
        hrf_mean /= hrf_norm
        hrf_covar /= hrf_norm ** 2
        sigma_h /= hrf_norm ** 2
        nrls_mean *= hrf_norm
        nrls_covar *= hrf_norm ** 2
        nrls_class_mean *= hrf_norm
        nrls_class_var *= hrf_norm ** 2
        mahalanobis_zero = mahalanobis(hrf_mean, np.zeros_like(hrf_mean),
                                       np.linalg.inv(hrf_covar))
        mahalanobis_cano = mahalanobis(hrf_mean, m_h, np.linalg.inv(hrf_covar))
        mahalanobis_diff = mahalanobis_cano - mahalanobis_zero
        mahalanobis_prod = mahalanobis_cano * mahalanobis_zero
        variation_coeff = np.sqrt((hrf_mean.T.dot(hrf_covar).dot(hrf_mean))
                                  /(hrf_mean.T.dot(hrf_mean))**2)
    if estimate_hrf and zero_constraint:
        hrf_mean = np.concatenate(([0], hrf_mean, [0]))
        # when using the zero constraint, the HRF covariance is padded with
        # zeros around the matrix; this may be a bad idea if we need it for
        # later computations...
        hrf_covar = np.concatenate(
            (np.zeros((hrf_covar.shape[0], 1)), hrf_covar, np.zeros((hrf_covar.shape[0], 1))),
            axis=1
        )
        hrf_covar = np.concatenate(
            (np.zeros((1, hrf_covar.shape[1])), hrf_covar, np.zeros((1, hrf_covar.shape[1]))),
            axis=0
        )

    if estimate_hrf:
        (delay_of_response, delay_of_undershoot, dispersion_of_response,
         dispersion_of_undershoot, ratio_resp_under, delay) = vt.fit_hrf_two_gammas(
             hrf_mean, dt, hrf_duration
         )
    else:
        (delay_of_response, delay_of_undershoot, dispersion_of_response,
         dispersion_of_undershoot, ratio_resp_under, delay) = (None, None, None,
                                                               None, None, None)

    ppm_a_nrl, ppm_g_nrl = vt.ppms_computation(
        nrls_mean, np.diagonal(nrls_covar), nrls_class_mean, nrls_class_var,
        threshold_a="intersect"
    )

    #+++++++++++++++++++++++  calculate contrast maps and variance +++++++++++++++++++++++#

    nb_contrasts = len(contrasts)
    if compute_contrasts and nb_contrasts > 0:
        logger.info('Computing contrasts ...')
        (contrasts_mean,
         contrasts_var,
         contrasts_class_mean,
         contrasts_class_var) = vt.contrasts_mean_var_classes(
             contrasts, condition_names, nrls_mean, nrls_covar,
             nrls_class_mean, nrls_class_var, nb_contrasts, nb_classes, nb_voxels
         )
        ppm_a_contrasts, ppm_g_contrasts = vt.ppms_computation(
            contrasts_mean, contrasts_var, contrasts_class_mean, contrasts_class_var
        )
        logger.info('Done computing contrasts.')
    else:
        (contrasts_mean, contrasts_var, contrasts_class_mean,
         contrasts_class_var, ppm_a_contrasts, ppm_g_contrasts) = (None, None,
                                                                   None, None,
                                                                   None, None)

    #+++++++++++++++++++++++  calculate contrast maps and variance  +++++++++++++++++++++++#

    logger.info("Nb iterations to reach criterion: %d", loop)
    logger.info("Computational time = %s min %s s",
                *(str(int(x)) for x in divmod(compute_time[-1], 60)))
    logger.debug('nrls_class_mean: %s', nrls_class_mean)
    logger.debug('nrls_class_var: %s', nrls_class_var)
    logger.debug("sigma_H = %s", str(sigma_h))
    logger.debug("beta = %s", str(beta))

    stimulus_induced_signal = vt.computeFit(hrf_mean, nrls_mean, X, nb_voxels, nb_scans)
    snr = 20 * np.log(
        np.linalg.norm(bold_data.astype(np.float))
        / np.linalg.norm((bold_data - stimulus_induced_signal - drift).astype(np.float))
    )
    snr /= np.log(10.)
    logger.info('snr comp = %f', snr)
    # ,FreeEnergyArray
    return (loop, nrls_mean, hrf_mean, hrf_covar, labels_proba, noise_var,
            nrls_class_mean, nrls_class_var, beta, drift_coeffs, drift,
            contrasts_mean, contrasts_var, compute_time[2:], compute_time_mean,
            nrls_covar, stimulus_induced_signal, mahalanobis_zero,
            mahalanobis_cano, mahalanobis_diff, mahalanobis_prod, ppm_a_nrl,
            ppm_g_nrl, ppm_a_contrasts, ppm_g_contrasts, variation_coeff,
            free_energy[1:], free_energy_crit[1:], beta_list[1:],
            delay_of_response, delay_of_undershoot, dispersion_of_response,
            dispersion_of_undershoot, ratio_resp_under, delay)
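
The examples above finally report an SNR in decibels: the ratio between the norm of the raw data and the norm of the residual left after removing the stimulus-induced signal and the drift. The same computation in isolation (toy inputs, just to make the formula concrete):

import numpy as np

def snr_db(data, stimulus_induced_signal, drift):
    # 20 * log10(||Y|| / ||Y - fit - drift||), as at the end of each example.
    residual = data - stimulus_induced_signal - drift
    return 20 * np.log10(np.linalg.norm(data) / np.linalg.norm(residual))

rng = np.random.RandomState(0)
data = rng.randn(100, 5)
fit = 0.9 * data                   # toy "stimulus-induced" component
drift = np.zeros_like(data)
print(snr_db(data, fit, drift))    # 20*log10(10) = 20 dB
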
Example #5
def jde_vem_bold(graph,
                 bold_data,
                 onsets,
                 durations,
                 hrf_duration,
                 nb_classes,
                 tr,
                 beta,
                 dt,
                 estimate_sigma_h=True,
                 sigma_h=0.05,
                 it_max=-1,
                 it_min=0,
                 estimate_beta=True,
                 contrasts=None,
                 compute_contrasts=False,
                 hrf_hyperprior=0,
                 estimate_hrf=True,
                 constrained=False,
                 zero_constraint=True,
                 drifts_type="poly",
                 seed=6537546):
    """This is the main function that computes the VEM analysis on BOLD data.
    This function uses optimized python functions.

    Parameters
    ----------
    graph : ndarray of lists
        represents the neighbours indexes of each voxels index
    bold_data : ndarray, shape (nb_scans, nb_voxels)
        raw data
    onsets : dict
        dictionary of onsets
    durations : # TODO
        # TODO
    hrf_duration : float
        hrf total time duration (in s)
    nb_classes : int
        the number of classes to classify the nrls. This parameter is provided
        for development purposes as most of the algorithm assumes two classes
    tr : float
        time of repetition
    beta : float
        the initial value of beta
    dt : float
        hrf temporal precision
    estimate_sigma_h : bool, optional
        toggle estimation of sigma H
    sigma_h : float, optional
        initial or fixed value of sigma H
    it_max : int, optional
        maximal computed iteration number
    it_min : int, optional
        minimal computed iteration number
    estimate_beta : bool, optional
        toggle the estimation of beta
    contrasts : OrderedDict, optional
        dict of contrasts to compute
    compute_contrasts : bool, optional
        if True, compute the contrasts defined in contrasts
    hrf_hyperprior : float
        # TODO
    estimate_hrf : bool, optional
        if True, estimate the HRF for each parcel, if False use the canonical HRF
    constrained : bool, optional
        if True, constrain the l2 norm of the HRF to 1
    zero_constraint : bool, optional
        if True, add zeros to the beginning and the end of the estimated HRF.
    drifts_type : str, optional
        set the drifts basis type used. Can be "poly" for polynomial or "cos"
        for cosine
    seed : int, optional
        seed used by numpy to initialize the random number generator

    Returns
    -------
    loop : int
        number of iterations before convergence
    nrls_mean : ndarray, shape (nb_voxels, nb_conditions)
        Neural response level mean value
    hrf_mean : ndarray, shape (hrf_len,)
        Hemodynamic response function mean value
    hrf_covar : ndarray, shape (hrf_len, hrf_len)
        Covariance matrix of the HRF
    labels_proba : ndarray, shape (nb_conditions, nb_classes, nb_voxels)
        probability of voxels being in one class
    noise_var : ndarray, shape (nb_voxels,)
        estimated noise variance
    nrls_class_mean : ndarray, shape (nb_conditions, nb_classes)
        estimated mean value of the gaussians of the classes
    nrls_class_var : ndarray, shape (nb_conditions, nb_classes)
        estimated variance of the gaussians of the classes
    beta : ndarray, shape (nb_conditions,)
        estimated beta
    drift_coeffs : ndarray, shape (# TODO)
        estimated coefficient of the drifts
    drift : ndarray, shape (# TODO)
        estimated drifts
    contrasts_mean : ndarray, shape (nb_voxels, len(contrasts))
        Contrasts computed from NRLs
    contrasts_var : ndarray, shape (nb_voxels, len(contrasts))
        Variance of the contrasts
    compute_time : list
        computation time of each iteration
    compute_time_mean : float
        computation mean time over iterations
    nrls_covar : ndarray, shape (nb_conditions, nb_conditions, nb_voxels)
        # TODO
    stimulus_induced_signal : ndarray, shape (nb_scans, nb_voxels)
        # TODO
    mahalanobis_zero : float
        Mahalanobis distance between estimated hrf_mean and the null vector
    mahalanobis_cano : float
        Mahalanobis distance between estimated hrf_mean and the canonical HRF
    mahalanobis_diff : float
        difference between mahalanobis_cano and mahalanobis_zero
    mahalanobis_prod : float
        product of mahalanobis_cano and mahalanobis_zero
    ppm_a_nrl : ndarray, shape (nb_voxels,)
        The posterior probability map using an alpha
    ppm_g_nrl : ndarray, shape (nb_voxels,)
        # TODO
    ppm_a_contrasts : ndarray, shape (nb_voxels,)
        # TODO
    ppm_g_contrasts : ndarray, shape (nb_voxels,)
        # TODO
    variation_coeff : float
        coefficient of variation of the HRF
    free_energy : list
        # TODO

    Notes
    -----
        See `A novel definition of the multivariate coefficient of variation
        <http://onlinelibrary.wiley.com/doi/10.1002/bimj.201000030/abstract>`_
        article for more information about the coefficient of variation.
    """

    logger.info("VEM started.")

    if not contrasts:
        contrasts = OrderedDict()

    np.random.seed(seed)

    nb_2_norm = 1
    normalizing = False
    regularizing = False

    if it_max <= 0:
        it_max = 100

    gamma = 7.5

    # Initialize sizes vectors
    hrf_len = np.int(np.ceil(hrf_duration / dt)) + 1

    nb_conditions = len(onsets)
    nb_scans = bold_data.shape[0]
    nb_voxels = bold_data.shape[1]
    X, occurence_matrix, condition_names = vt.create_conditions(
        onsets, durations, nb_conditions, nb_scans, hrf_len, tr, dt)

    neighbours_indexes = vt.create_neighbours(graph)

    order = 2
    if regularizing:
        regularization = np.ones(hrf_len)
        regularization[hrf_len // 3:hrf_len // 2] = 2
        regularization[hrf_len // 2:2 * hrf_len // 3] = 5
        regularization[2 * hrf_len // 3:3 * hrf_len // 4] = 7
        regularization[3 * hrf_len // 4:] = 10
        # regularization[hrf_len//2:] = 10
    else:
        regularization = None

    d2 = vt.buildFiniteDiffMatrix(order, hrf_len, regularization)
    hrf_regu_prior_inv = d2.T.dot(d2) / pow(dt, 2 * order)

    if estimate_hrf and zero_constraint:
        hrf_len = hrf_len - 2
        hrf_regu_prior_inv = hrf_regu_prior_inv[1:-1, 1:-1]
        occurence_matrix = occurence_matrix[:, :, 1:-1]

    noise_struct = np.identity(nb_scans)

    noise_var = np.ones(nb_voxels)

    if nb_classes != 2:
        logger.warning('The number of classes is different from two.')

    labels_proba = np.zeros((nb_conditions, nb_classes, nb_voxels),
                            dtype=np.float64)
    logger.info("Labels are initialized by setting everything to {}".format(
        1. / nb_classes))
    labels_proba[:, :, :] = 1. / nb_classes

    m_h = getCanoHRF(hrf_duration, dt)[1][:hrf_len]
    hrf_mean = np.array(m_h).astype(np.float64)

    if estimate_hrf:
        hrf_covar = np.identity(hrf_len, dtype=np.float64)
    else:
        hrf_covar = np.zeros((hrf_len, hrf_len), dtype=np.float64)

    beta = beta * np.ones(nb_conditions, dtype=np.float64)
    beta_list = [beta.copy()]

    if drifts_type == "poly":
        drift_basis = vt.poly_drifts_basis(nb_scans, 4, tr)
    elif drifts_type == "cos":
        drift_basis = vt.cosine_drifts_basis(nb_scans, 64, tr)
    else:
        raise Exception('drift type "%s" is not supported' % drifts_type)

    drift_coeffs = vt.drifts_coeffs_fit(bold_data, drift_basis)
    drift = drift_basis.dot(drift_coeffs)
    bold_data_drift = bold_data - drift

    # Parameters Gaussian mixtures
    nrls_class_mean = 2 * np.ones((nb_conditions, nb_classes))
    nrls_class_mean[:, 0] = 0

    nrls_class_var = 0.3 * np.ones(
        (nb_conditions, nb_classes), dtype=np.float64)

    nrls_mean = (
        np.random.normal(nrls_class_mean, nrls_class_var)[:, :, np.newaxis] *
        labels_proba).sum(axis=1).T

    nrls_covar = np.identity(nb_conditions)[:, :, np.newaxis] + np.zeros(
        (1, 1, nb_voxels))

    thresh_free_energy = 1e-4
    free_energy = [1.]
    free_energy_crit = [1.]

    compute_time = []
    start_time = time.time()
    loop = 0
    while (loop <= it_min
           or ((np.asarray(free_energy_crit[-5:]) > thresh_free_energy).any()
               and loop < it_max)):

        logger.info("{:-^80}".format(" Iteration n°" + str(loop + 1) + " "))

        logger.info("Expectation A step...")
        logger.debug("Before: nrls_mean = %s, nrls_covar = %s", nrls_mean,
                     nrls_covar)
        nrls_mean, nrls_covar = vt.nrls_expectation(
            hrf_mean, nrls_mean, occurence_matrix, noise_struct, labels_proba,
            nrls_class_mean, nrls_class_var, nb_conditions, bold_data_drift,
            nrls_covar, hrf_covar, noise_var)
        logger.debug("After: nrls_mean = %s, nrls_covar = %s", nrls_mean,
                     nrls_covar)

        logger.info("Expectation Z step...")
        logger.debug("Before: labels_proba = %s, labels_proba = %s",
                     labels_proba, labels_proba)
        labels_proba = vt.labels_expectation(nrls_covar,
                                             nrls_mean,
                                             nrls_class_var,
                                             nrls_class_mean,
                                             beta,
                                             labels_proba,
                                             neighbours_indexes,
                                             nb_conditions,
                                             nb_classes,
                                             nb_voxels,
                                             parallel=True)
        logger.debug("After: labels_proba = %s, labels_proba = %s",
                     labels_proba, labels_proba)

        if estimate_hrf:
            logger.info("Expectation H step...")
            logger.debug("Before: hrf_mean = %s, hrf_covar = %s", hrf_mean,
                         hrf_covar)
            hrf_mean, hrf_covar = vt.hrf_expectation(
                nrls_covar, nrls_mean, occurence_matrix, noise_struct,
                hrf_regu_prior_inv, sigma_h, nb_voxels, bold_data_drift,
                noise_var)

            if constrained:
                hrf_mean = vt.norm1_constraint(hrf_mean, hrf_covar)
                hrf_covar[:] = 0

            logger.debug("After: hrf_mean = %s, hrf_covar = %s", hrf_mean,
                         hrf_covar)

            # Normalizing H at each nb_2_norm iterations:
            if not constrained and normalizing:
                # Normalization is done before the sigma_h, nrls_class_mean and
                # nrls_class_var estimation, so those quantities are not
                # rescaled here
                if (loop + 1) % nb_2_norm == 0:
                    hrf_norm = np.linalg.norm(hrf_mean)
                    hrf_mean /= hrf_norm
                    hrf_covar /= hrf_norm**2
                    nrls_mean *= hrf_norm
                    nrls_covar *= hrf_norm**2

        if estimate_hrf and estimate_sigma_h:
            logger.info("Maximization sigma_H step...")
            logger.debug("Before: sigma_h = %s", sigma_h)
            if hrf_hyperprior > 0:
                sigma_h = vt.maximization_sigmaH_prior(hrf_len, hrf_covar,
                                                       hrf_regu_prior_inv,
                                                       hrf_mean,
                                                       hrf_hyperprior)
            else:
                sigma_h = vt.maximization_sigmaH(hrf_len, hrf_covar,
                                                 hrf_regu_prior_inv, hrf_mean)
            logger.debug("After: sigma_h = %s", sigma_h)

        logger.info("Maximization (mu,sigma) step...")
        logger.debug("Before: nrls_class_mean = %s, nrls_class_var = %s",
                     nrls_class_mean, nrls_class_var)
        nrls_class_mean, nrls_class_var = vt.maximization_class_proba(
            labels_proba, nrls_mean, nrls_covar)
        logger.debug("After: nrls_class_mean = %s, nrls_class_var = %s",
                     nrls_class_mean, nrls_class_var)

        logger.info("Maximization L step...")
        logger.debug("Before: drift_coeffs = %s", drift_coeffs)
        drift_coeffs = vt.maximization_drift_coeffs(bold_data, nrls_mean,
                                                    occurence_matrix, hrf_mean,
                                                    noise_struct, drift_basis)
        logger.debug("After: drift_coeffs = %s", drift_coeffs)

        drift = drift_basis.dot(drift_coeffs)
        bold_data_drift = bold_data - drift
        if estimate_beta:
            logger.info("Maximization beta step...")
            for cond_nb in xrange(0, nb_conditions):
                beta[cond_nb], success = vt.beta_maximization(
                    beta[cond_nb] * np.ones((1, )),
                    labels_proba[cond_nb, :, :], neighbours_indexes, gamma)
            beta_list.append(beta.copy())
            logger.debug("beta = %s", str(beta))

        logger.info("Maximization sigma noise step...")
        noise_var = vt.maximization_noise_var(occurence_matrix, hrf_mean,
                                              hrf_covar, nrls_mean, nrls_covar,
                                              noise_struct, bold_data_drift,
                                              nb_scans)

        # Computing Free Energy
        free_energy.append(
            vt.free_energy_computation(
                nrls_mean, nrls_covar, hrf_mean, hrf_covar, hrf_len,
                labels_proba, bold_data_drift, occurence_matrix, noise_var,
                noise_struct, nb_conditions, nb_voxels, nb_scans, nb_classes,
                nrls_class_mean, nrls_class_var,
                neighbours_indexes, beta, sigma_h,
                np.linalg.inv(hrf_regu_prior_inv), hrf_regu_prior_inv, gamma,
                hrf_hyperprior))

        free_energy_crit.append(
            abs((free_energy[-2] - free_energy[-1]) / free_energy[-2]))

        logger.info("Convergence criteria: %f (Threshold = %f)",
                    free_energy_crit[-1], thresh_free_energy)
        loop += 1
        compute_time.append(time.time() - start_time)

    compute_time_mean = compute_time[-1] / loop

    mahalanobis_zero = np.nan
    mahalanobis_cano = np.nan
    mahalanobis_diff = np.nan
    mahalanobis_prod = np.nan
    variation_coeff = np.nan

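    # If the HRF scale was not fixed during the iterations (no norm-1
    # constraint and no periodic normalization), normalize it once here,
    # rescale the NRLs and mixture parameters accordingly, and characterize the
    # estimated HRF with Mahalanobis distances to the null and canonical shapes
    # plus a variation coefficient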
    if estimate_hrf and not constrained and not normalizing:
        hrf_norm = np.linalg.norm(hrf_mean)
        hrf_mean /= hrf_norm
        hrf_covar /= hrf_norm**2
        sigma_h /= hrf_norm**2
        nrls_mean *= hrf_norm
        nrls_covar *= hrf_norm**2
        nrls_class_mean *= hrf_norm
        nrls_class_var *= hrf_norm**2
        mahalanobis_zero = mahalanobis(hrf_mean, np.zeros_like(hrf_mean),
                                       np.linalg.inv(hrf_covar))
        mahalanobis_cano = mahalanobis(hrf_mean, m_h, np.linalg.inv(hrf_covar))
        mahalanobis_diff = mahalanobis_cano - mahalanobis_zero
        mahalanobis_prod = mahalanobis_cano * mahalanobis_zero
        variation_coeff = np.sqrt((hrf_mean.T.dot(hrf_covar).dot(hrf_mean)) /
                                  (hrf_mean.T.dot(hrf_mean))**2)

    if estimate_hrf and zero_constraint:
        hrf_mean = np.concatenate(([0], hrf_mean, [0]))

        # When using the zero constraint, the HRF covariance is padded with
        # arbitrary zeros around the matrix; this may be a bad idea if it is
        # needed for later computations...
        hrf_covar = np.concatenate((np.zeros(
            (hrf_covar.shape[0], 1)), hrf_covar,
                                    np.zeros((hrf_covar.shape[0], 1))),
                                   axis=1)

        hrf_covar = np.concatenate((np.zeros(
            (1, hrf_covar.shape[1])), hrf_covar,
                                    np.zeros((1, hrf_covar.shape[1]))),
                                   axis=0)
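        # (the two concatenations above are equivalent to
        #  hrf_covar = np.pad(hrf_covar, 1, mode='constant'))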

    if estimate_hrf:
        (delay_of_response, delay_of_undershoot, dispersion_of_response,
         dispersion_of_undershoot, ratio_resp_under,
         delay) = vt.fit_hrf_two_gammas(hrf_mean, dt, hrf_duration)
    else:
        (delay_of_response, delay_of_undershoot, dispersion_of_response,
         dispersion_of_undershoot, ratio_resp_under,
         delay) = (None, None, None, None, None, None)

    ppm_a_nrl, ppm_g_nrl = vt.ppms_computation(nrls_mean,
                                               np.diagonal(nrls_covar),
                                               nrls_class_mean,
                                               nrls_class_var,
                                               threshold_a="intersect")

    # Calculate contrast maps and variance
    nb_contrasts = len(contrasts)
    if compute_contrasts and nb_contrasts > 0:
        logger.info('Computing contrasts ...')
        (contrasts_mean, contrasts_var, contrasts_class_mean,
         contrasts_class_var) = vt.contrasts_mean_var_classes(
             contrasts, condition_names, nrls_mean, nrls_covar,
             nrls_class_mean, nrls_class_var, nb_contrasts, nb_classes,
             nb_voxels)

        ppm_a_contrasts, ppm_g_contrasts = vt.ppms_computation(
            contrasts_mean, contrasts_var, contrasts_class_mean,
            contrasts_class_var)
        logger.info('Done computing contrasts.')
    else:
        (contrasts_mean, contrasts_var, contrasts_class_mean,
         contrasts_class_var, ppm_a_contrasts,
         ppm_g_contrasts) = (None, None, None, None, None, None)

    logger.info("Number of iterations to reach criterion: %d", loop)
    logger.info("Computational time = {t[0]:.0f} min {t[1]:.0f} s".format(
        t=divmod(compute_time[-1], 60)))
    logger.debug('nrls_class_mean: %s', nrls_class_mean)
    logger.debug('nrls_class_var: %s', nrls_class_var)
    logger.debug("sigma_H = %s", str(sigma_h))
    logger.debug("beta = %s", str(beta))

    stimulus_induced_signal = vt.computeFit(hrf_mean, nrls_mean, X, nb_voxels,
                                            nb_scans)
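    # SNR of the fit in dB: 20 * log10 of the ratio between the norm of the raw
    # data and the norm of the residual (data minus drift minus fitted signal)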
    snr = 20 * np.log(
        np.linalg.norm(bold_data.astype(np.float)) / np.linalg.norm(
            (bold_data_drift - stimulus_induced_signal).astype(np.float)))
    snr /= np.log(10.)
    logger.info('SNR comp = %f', snr)

    return (loop, nrls_mean, hrf_mean, hrf_covar, labels_proba, noise_var,
            nrls_class_mean, nrls_class_var, beta, drift_coeffs, drift,
            contrasts_mean, contrasts_var, compute_time[2:], compute_time_mean,
            nrls_covar, stimulus_induced_signal, mahalanobis_zero,
            mahalanobis_cano, mahalanobis_diff, mahalanobis_prod, ppm_a_nrl,
            ppm_g_nrl, ppm_a_contrasts, ppm_g_contrasts, variation_coeff,
            free_energy[1:], free_energy_crit[1:], beta_list[1:],
            delay_of_response, delay_of_undershoot, dispersion_of_response,
            dispersion_of_undershoot, ratio_resp_under, delay)
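
Side note (not part of the code examples on this page): the stopping rule used
in the loop above combines a minimum number of iterations with a sliding window
over the last five relative free-energy variations. Below is a minimal,
self-contained sketch of that rule on a made-up free-energy trace; all values
are placeholders.

import numpy as np

it_min, it_max = 5, 100              # toy iteration bounds
thresh_free_energy = 1e-4            # toy convergence threshold

free_energy = [1.]                   # dummy starting value, as in the function
free_energy_crit = [1.]

loop = 0
while (loop <= it_min
       or ((np.asarray(free_energy_crit[-5:]) > thresh_free_energy).any()
           and loop < it_max)):
    # fake monotonically increasing free energy with shrinking increments
    free_energy.append(free_energy[-1] + 1. / (loop + 1) ** 2)
    free_energy_crit.append(
        abs((free_energy[-2] - free_energy[-1]) / free_energy[-2]))
    loop += 1

print("stopped after %d iterations" % loop)
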
Ejemplo n.º 6
0
def Main_vbjde_Extension_constrained_stable(graph, Y, Onsets, Thrf, K, TR, beta,
                                            dt, scale=1, estimateSigmaH=True,
                                            sigmaH=0.05, NitMax=-1,
                                            NitMin=1, estimateBeta=True,
                                            PLOT=False, contrasts=[],
                                            computeContrast=False,
                                            gamma_h=0):
    """ Version modified by Lofti from Christine's version """
    logger.info(
        "Fast EM with C extension started ... Here is the stable version !")

    np.random.seed(6537546)

    # Initialize parameters
    S = 100
    if NitMax < 0:
        NitMax = 100
    gamma = 7.5  # 7.5
    gradientStep = 0.003
    MaxItGrad = 200
    Thresh = 1e-5

    # Initialize sizes vectors
    D = np.int(np.ceil(Thrf / dt)) + 1
    M = len(Onsets)
    N = Y.shape[0]
    J = Y.shape[1]
    l = np.int(np.sqrt(J))
    condition_names = []

    # Neighbours
    maxNeighbours = max([len(nl) for nl in graph])
    neighboursIndexes = np.zeros((J, maxNeighbours), dtype=np.int32)
    neighboursIndexes -= 1
    for i in xrange(J):
        neighboursIndexes[i, :len(graph[i])] = graph[i]
    # Conditions
    X = OrderedDict([])
    for condition, Ons in Onsets.iteritems():
        X[condition] = vt.compute_mat_X_2(N, TR, D, dt, Ons)
        condition_names += [condition]
    XX = np.zeros((M, N, D), dtype=np.int32)
    nc = 0
    for condition, Ons in Onsets.iteritems():
        XX[nc, :, :] = X[condition]
        nc += 1
    # Covariance matrix
    order = 2
    D2 = vt.buildFiniteDiffMatrix(order, D)
    R = np.dot(D2, D2) / pow(dt, 2 * order)
    invR = np.linalg.inv(R)
    Det_invR = np.linalg.det(invR)

    Gamma = np.identity(N)
    Det_Gamma = np.linalg.det(Gamma)

    Crit_H = 1
    Crit_Z = 1
    Crit_A = 1
    Crit_AH = 1
    AH = np.zeros((J, M, D), dtype=np.float64)
    AH1 = np.zeros((J, M, D), dtype=np.float64)
    Crit_FreeEnergy = 1
    cTime = []
    cA = []
    cH = []
    cZ = []
    cAH = []

    CONTRAST = np.zeros((J, len(contrasts)), dtype=np.float64)
    CONTRASTVAR = np.zeros((J, len(contrasts)), dtype=np.float64)
    Q_barnCond = np.zeros((M, M, D, D), dtype=np.float64)
    XGamma = np.zeros((M, D, N), dtype=np.float64)
    m1 = 0
    for k1 in X:  # Loop over the M conditions
        m2 = 0
        for k2 in X:
            Q_barnCond[m1, m2, :, :] = np.dot(
                np.dot(X[k1].transpose(), Gamma), X[k2])
            m2 += 1
        XGamma[m1, :, :] = np.dot(X[k1].transpose(), Gamma)
        m1 += 1

    sigma_epsilone = np.ones(J)
    logger.info(
        "Labels are initialized by setting active probabilities to ones ...")
    q_Z = np.zeros((M, K, J), dtype=np.float64)
    q_Z[:, 1, :] = 1
    q_Z1 = np.zeros((M, K, J), dtype=np.float64)
    Z_tilde = q_Z.copy()

    TT, m_h = getCanoHRF(Thrf, dt)  # TODO: check
    m_h = m_h[:D]
    m_H = np.array(m_h).astype(np.float64)
    m_H1 = np.array(m_h)
    sigmaH1 = sigmaH
    Sigma_H = np.ones((D, D), dtype=np.float64)

    Beta = beta * np.ones((M), dtype=np.float64)
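    # Low-frequency drift: 4th-order polynomial basis P, least-squares drift
    # coefficients L and drift-corrected data y_tilde = Y - PL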
    P = vt.PolyMat(N, 4, TR)
    L = vt.polyFit(Y, TR, 4, P)
    PL = np.dot(P, L)
    y_tilde = Y - PL
    Ndrift = L.shape[0]

    sigma_M = np.ones((M, K), dtype=np.float64)
    sigma_M[:, 0] = 0.5
    sigma_M[:, 1] = 0.6
    mu_M = np.zeros((M, K), dtype=np.float64)
    for k in xrange(1, K):
        mu_M[:, k] = 1  # InitMean
    Sigma_A = np.zeros((M, M, J), np.float64)
    for j in xrange(0, J):
        Sigma_A[:, :, j] = 0.01 * np.identity(M)
    m_A = np.zeros((J, M), dtype=np.float64)
    m_A1 = np.zeros((J, M), dtype=np.float64)
    for j in xrange(0, J):
        for m in xrange(0, M):
            for k in xrange(0, K):
                m_A[j, m] += np.random.normal(mu_M[m, k],
                                              np.sqrt(sigma_M[m, k])) * q_Z[m, k, j]
    m_A1 = m_A
    h_norm = []

    t1 = time.time()

    ##########################################################################
    # VBJDE num. iter. minimum

    ni = 0

    while ((ni < NitMin + 1) or ((Crit_AH > Thresh) and (ni < NitMax))):

        logger.info("------------------------------ Iteration n° " +
                    str(ni + 1) + " ------------------------------")

        #####################
        # EXPECTATION
        #####################

        # A
        logger.info("E A step ...")
        UtilsC.expectation_A(q_Z, mu_M, sigma_M, PL, sigma_epsilone, Gamma,
                             Sigma_H, Y, y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, K)

        # crit. A
        DIFF = np.reshape(m_A - m_A1, (M * J))
        Crit_A = (
            np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(m_A1, (M * J)))) ** 2
        cA += [Crit_A]
        m_A1[:, :] = m_A[:, :]

        # HRF h
        UtilsC.expectation_H(XGamma, Q_barnCond, sigma_epsilone, Gamma, R, Sigma_H, Y,
                             y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, scale, sigmaH)
        #m_H[0] = 0
        #m_H[-1] = 0
        # Constrain with optimization strategy
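        # (projection of the posterior HRF mean m_H onto the set of vectors
        #  with zero endpoints and L2 norm at most 1, in the Sigma_H^-1 metric,
        #  solved as a small convex program with cvxpy)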
        import cvxpy as cvx
        m, n = Sigma_H.shape
        Sigma_H_inv = np.linalg.inv(Sigma_H)
        zeros_H = np.zeros_like(m_H[:, np.newaxis])
        # Construct the problem. PRIMAL
        h = cvx.Variable(n)
        expression = cvx.quad_form(h - m_H[:, np.newaxis], Sigma_H_inv)
        objective = cvx.Minimize(expression)
        #constraints = [h[0] == 0, h[-1]==0, h >= zeros_H, cvx.square(cvx.norm(h,2))<=1]
        constraints = [h[0] == 0, h[-1] == 0, cvx.square(cvx.norm(h, 2)) <= 1]
        prob = cvx.Problem(objective, constraints)
        result = prob.solve(verbose=0, solver=cvx.CVXOPT)
        # Now we update the mean of h
        m_H_old = m_H
        Sigma_H_old = Sigma_H
        m_H = np.squeeze(np.array((h.value)))
        Sigma_H = np.zeros_like(Sigma_H)
        # and the norm
        h_norm += [np.linalg.norm(m_H)]

        # crit. h
        Crit_H = (np.linalg.norm(m_H - m_H1) / np.linalg.norm(m_H1)) ** 2
        cH += [Crit_H]
        m_H1[:] = m_H[:]

        # crit. AH
        for d in xrange(0, D):
            AH[:, :, d] = m_A[:, :] * m_H[d]
        DIFF = np.reshape(AH - AH1, (M * J * D))
        Crit_AH = (np.linalg.norm(
            DIFF) / (np.linalg.norm(np.reshape(AH1, (M * J * D))) + eps)) ** 2
        cAH += [Crit_AH]
        AH1[:, :, :] = AH[:, :, :]

        # Z labels
        logger.info("E Z step ...")
        UtilsC.expectation_Z(Sigma_A, m_A, sigma_M, Beta, Z_tilde, mu_M,
                             q_Z, neighboursIndexes.astype(np.int32), M, J, K, maxNeighbours)

        # crit. Z
        DIFF = np.reshape(q_Z - q_Z1, (M * K * J))
        Crit_Z = (np.linalg.norm(DIFF) /
                  (np.linalg.norm(np.reshape(q_Z1, (M * K * J))) + eps)) ** 2
        cZ += [Crit_Z]
        q_Z1[:, :, :] = q_Z[:, :, :]

        #####################
        # MAXIMIZATION
        #####################

        # HRF: Sigma_h
        if estimateSigmaH:
            logger.info("M sigma_H step ...")
            if gamma_h > 0:
                sigmaH = vt.maximization_sigmaH_prior(
                    D, Sigma_H, R, m_H, gamma_h)
            else:
                sigmaH = vt.maximization_sigmaH(D, Sigma_H, R, m_H)
            logger.info('sigmaH = %s', str(sigmaH))

        # (mu,sigma)
        logger.info("M (mu,sigma) step ...")
        mu_M, sigma_M = vt.maximization_mu_sigma(
            mu_M, sigma_M, q_Z, m_A, K, M, Sigma_A)

        # Drift L
        UtilsC.maximization_L(
            Y, m_A, m_H, L, P, XX.astype(np.int32), J, D, M, Ndrift, N)
        PL = np.dot(P, L)
        y_tilde = Y - PL

        # Beta
        if estimateBeta:
            logger.info("estimating beta")
            for m in xrange(0, M):
                Beta[m] = UtilsC.maximization_beta(beta, q_Z[m, :, :].astype(np.float64), Z_tilde[m, :, :].astype(
                    np.float64), J, K, neighboursIndexes.astype(np.int32), gamma, maxNeighbours, MaxItGrad, gradientStep)
            logger.info("End estimating beta")
            logger.info(Beta)

        # Sigma noise
        logger.info("M sigma noise step ...")
        UtilsC.maximization_sigma_noise(
            Gamma, PL, sigma_epsilone, Sigma_H, Y, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N)

        # Update index
        ni += 1

        t02 = time.time()
        cTime += [t02 - t1]

    t2 = time.time()

    ##########################################################################
    # PLOTS and SNR computation

    if PLOT and 0:
        # disabled plotting block; it also refers to diagnostics (cFE,
        # FreeEnergyArray, SUM_q_Z_array, mu1_array, h_norm_array) that this
        # version of the function does not track
        import matplotlib
        from matplotlib.pyplot import (figure, grid, hold, legend, plot,
                                       savefig)
        font = {'size': 15}
        matplotlib.rc('font', **font)
        savefig('./HRF_Iter_CompMod.png')
        hold(False)
        figure(2)
        plot(cAH[1:-1], 'lightblue')
        hold(True)
        plot(cFE[1:-1], 'm')
        hold(False)
        legend(('CAH', 'CFE'))
        grid(True)
        savefig('./Crit_CompMod.png')
        figure(3)
        plot(FreeEnergyArray)
        grid(True)
        savefig('./FreeEnergy_CompMod.png')

        figure(4)
        for m in xrange(M):
            plot(SUM_q_Z_array[m])
            hold(True)
        hold(False)
        savefig('./Sum_q_Z_Iter_CompMod.png')

        figure(5)
        for m in xrange(M):
            plot(mu1_array[m])
            hold(True)
        hold(False)
        savefig('./mu1_Iter_CompMod.png')

        figure(6)
        plot(h_norm_array)
        savefig('./HRF_Norm_CompMod.png')

        Data_save = xndarray(h_norm_array, ['Iteration'])
        Data_save.save('./HRF_Norm_Comp.nii')

    CompTime = t2 - t1
    cTimeMean = CompTime / ni

    """
    Norm = np.linalg.norm(m_H)
    m_H /= Norm
    Sigma_H /= Norm**2
    sigmaH /= Norm**2
    m_A *= Norm
    Sigma_A *= Norm**2
    mu_M *= Norm
    sigma_M *= Norm**2
    sigma_M = np.sqrt(np.sqrt(sigma_M))
    """
    logger.info("Nb iterations to reach criterion: %d", ni)
    logger.info("Computational time = %s min %s s", str(
        np.int(CompTime // 60)), str(np.int(CompTime % 60)))
    logger.info('mu_M: %s', mu_M)
    logger.info('sigma_M: %s', sigma_M)
    logger.info("sigma_H = %s", str(sigmaH))
    logger.info("Beta = %s", str(Beta))

    StimulusInducedSignal = vt.computeFit(m_H, m_A, X, J, N)
    SNR = 20 * \
        np.log(
            np.linalg.norm(Y) / np.linalg.norm(Y - StimulusInducedSignal - PL))
    SNR /= np.log(10.)
    logger.info('SNR comp = %f', SNR)
    return ni, m_A, m_H, q_Z, sigma_epsilone, mu_M, sigma_M, Beta, L, PL, CONTRAST, CONTRASTVAR, cA[2:], cH[2:], cZ[2:], cAH[2:], cTime[2:], cTimeMean, Sigma_A, StimulusInducedSignal
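
Side note (not part of the code examples on this page): the constrained E-H
step in the example above projects the HRF posterior mean onto the set
{h : ||h||_2 <= 1, h[0] = h[-1] = 0} in the Sigma_H^-1 metric, using cvxpy.
Below is a minimal, self-contained sketch of that projection on arbitrary
placeholder inputs; the CVXOPT solver used above is optional, cvxpy's default
solver also handles this problem.

import numpy as np
import cvxpy as cvx

D = 6                                # placeholder HRF length
m_H = np.random.randn(D)             # placeholder posterior HRF mean
Sigma_H_inv = np.eye(D)              # placeholder inverse posterior covariance

h = cvx.Variable(D)
objective = cvx.Minimize(cvx.quad_form(h - m_H, Sigma_H_inv))
constraints = [h[0] == 0, h[-1] == 0, cvx.square(cvx.norm(h, 2)) <= 1]
prob = cvx.Problem(objective, constraints)
prob.solve()

m_H_constrained = np.asarray(h.value).ravel()
print(np.linalg.norm(m_H_constrained))   # <= 1 up to solver tolerance
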
Ejemplo n.º 7
0
def Main_vbjde_Extension_constrained(graph, Y, Onsets, Thrf, K, TR, beta,
                                     dt, scale=1, estimateSigmaH=True,
                                     sigmaH=0.05, NitMax=-1,
                                     NitMin=1, estimateBeta=True,
                                     PLOT=False, contrasts=[],
                                     computeContrast=False,
                                     gamma_h=0, estimateHRF=True,
                                     TrueHrfFlag=False,
                                     HrfFilename='hrf.nii',
                                     estimateLabels=True,
                                     LabelsFilename='labels.nii',
                                     MFapprox=False, InitVar=0.5,
                                     InitMean=2.0, MiniVEMFlag=False,
                                     NbItMiniVem=5):
    # VBJDE Function for BOLD with contraints

    logger.info("Fast EM with C extension started ...")
    np.random.seed(6537546)

    ##########################################################################
    # INITIALIZATIONS
    # Initialize parameters
    tau1 = 0.0
    tau2 = 0.0
    S = 100
    Init_sigmaH = sigmaH
    Nb2Norm = 1
    NormFlag = False
    if NitMax < 0:
        NitMax = 100
    gamma = 7.5
    #gamma_h = 1000
    gradientStep = 0.003
    MaxItGrad = 200
    Thresh = 1e-5
    Thresh_FreeEnergy = 1e-5
    estimateLabels = True  # WARNING!! They should be estimated

    # Initialize sizes vectors
    D = int(np.ceil(Thrf / dt)) + 1  # D = int(np.ceil(Thrf/dt))
    M = len(Onsets)
    N = Y.shape[0]
    J = Y.shape[1]
    l = int(np.sqrt(J))
    condition_names = []

    # Neighbours
    maxNeighbours = max([len(nl) for nl in graph])
    neighboursIndexes = np.zeros((J, maxNeighbours), dtype=np.int32)
    neighboursIndexes -= 1
    for i in xrange(J):
        neighboursIndexes[i, :len(graph[i])] = graph[i]
    # Conditions
    X = OrderedDict([])
    for condition, Ons in Onsets.iteritems():
        X[condition] = vt.compute_mat_X_2(N, TR, D, dt, Ons)
        condition_names += [condition]
    XX = np.zeros((M, N, D), dtype=np.int32)
    nc = 0
    for condition, Ons in Onsets.iteritems():
        XX[nc, :, :] = X[condition]
        nc += 1
    # Covariance matrix
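    # (R acts as the smoothness prior matrix on the HRF: it is built from the
    #  order-2 finite-difference operator D2 and scaled by dt**(2 * order))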
    order = 2
    D2 = vt.buildFiniteDiffMatrix(order, D)
    R = np.dot(D2, D2) / pow(dt, 2 * order)
    invR = np.linalg.inv(R)
    Det_invR = np.linalg.det(invR)

    Gamma = np.identity(N)
    Det_Gamma = np.linalg.det(Gamma)

    p_Wtilde = np.zeros((M, K), dtype=np.float64)
    p_Wtilde1 = np.zeros((M, K), dtype=np.float64)
    p_Wtilde[:, 1] = 1

    Crit_H = 1
    Crit_Z = 1
    Crit_A = 1
    Crit_AH = 1
    AH = np.zeros((J, M, D), dtype=np.float64)
    AH1 = np.zeros((J, M, D), dtype=np.float64)
    Crit_FreeEnergy = 1

    cA = []
    cH = []
    cZ = []
    cAH = []
    FreeEnergy_Iter = []
    cTime = []
    cFE = []

    SUM_q_Z = [[] for m in xrange(M)]
    mu1 = [[] for m in xrange(M)]
    h_norm = []
    h_norm2 = []

    CONTRAST = np.zeros((J, len(contrasts)), dtype=np.float64)
    CONTRASTVAR = np.zeros((J, len(contrasts)), dtype=np.float64)
    Q_barnCond = np.zeros((M, M, D, D), dtype=np.float64)
    XGamma = np.zeros((M, D, N), dtype=np.float64)
    m1 = 0
    for k1 in X:  # Loop over the M conditions
        m2 = 0
        for k2 in X:
            Q_barnCond[m1, m2, :, :] = np.dot(
                np.dot(X[k1].transpose(), Gamma), X[k2])
            m2 += 1
        XGamma[m1, :, :] = np.dot(X[k1].transpose(), Gamma)
        m1 += 1

    if MiniVEMFlag:
        logger.info("MiniVEM to choose the best initialisation...")
        """InitVar, InitMean, gamma_h = MiniVEM_CompMod(Thrf,TR,dt,beta,Y,K,
                                                     gamma,gradientStep,
                                                     MaxItGrad,D,M,N,J,S,
                                                     maxNeighbours,
                                                     neighboursIndexes,
                                                     XX,X,R,Det_invR,Gamma,
                                                     Det_Gamma,
                                                     scale,Q_barnCond,XGamma,
                                                     NbItMiniVem,
                                                     sigmaH,estimateHRF)"""

        InitVar, InitMean, gamma_h = vt.MiniVEM_CompMod(Thrf, TR, dt, beta, Y, K, gamma, gradientStep, MaxItGrad, D, M, N, J, S, maxNeighbours,
                                                        neighboursIndexes, XX, X, R, Det_invR, Gamma, Det_Gamma, p_Wtilde, scale, Q_barnCond, XGamma, tau1, tau2, NbItMiniVem, sigmaH, estimateHRF)

    sigmaH = Init_sigmaH
    sigma_epsilone = np.ones(J)
    logger.info(
        "Labels are initialized by setting active probabilities to ones ...")
    q_Z = np.zeros((M, K, J), dtype=np.float64)
    q_Z[:, 1, :] = 1
    q_Z1 = np.zeros((M, K, J), dtype=np.float64)
    Z_tilde = q_Z.copy()

    # TT,m_h = getCanoHRF(Thrf-dt,dt) #TODO: check
    TT, m_h = getCanoHRF(Thrf, dt)  # TODO: check
    m_h = m_h[:D]
    m_H = np.array(m_h).astype(np.float64)
    m_H1 = np.array(m_h)
    sigmaH1 = sigmaH
    if estimateHRF:
        Sigma_H = np.ones((D, D), dtype=np.float64)
    else:
        Sigma_H = np.zeros((D, D), dtype=np.float64)

    Beta = beta * np.ones((M), dtype=np.float64)
    P = vt.PolyMat(N, 4, TR)
    L = vt.polyFit(Y, TR, 4, P)
    PL = np.dot(P, L)
    y_tilde = Y - PL
    Ndrift = L.shape[0]

    sigma_M = np.ones((M, K), dtype=np.float64)
    sigma_M[:, 0] = 0.5
    sigma_M[:, 1] = 0.6
    mu_M = np.zeros((M, K), dtype=np.float64)
    for k in xrange(1, K):
        mu_M[:, k] = InitMean
    Sigma_A = np.zeros((M, M, J), np.float64)
    for j in xrange(0, J):
        Sigma_A[:, :, j] = 0.01 * np.identity(M)
    m_A = np.zeros((J, M), dtype=np.float64)
    m_A1 = np.zeros((J, M), dtype=np.float64)
    for j in xrange(0, J):
        for m in xrange(0, M):
            for k in xrange(0, K):
                m_A[j, m] += np.random.normal(mu_M[m, k],
                                              np.sqrt(sigma_M[m, k])) * q_Z[m, k, j]
    m_A1 = m_A

    t1 = time.time()

    ##########################################################################
    # VBJDE num. iter. minimum

    ni = 0

    while ((ni < NitMin) or (((Crit_FreeEnergy > Thresh_FreeEnergy) or (Crit_AH > Thresh)) and (ni < NitMax))):

        logger.info("------------------------------ Iteration n° " +
                    str(ni + 1) + " ------------------------------")

        #####################
        # EXPECTATION
        #####################

        # A
        logger.info("E A step ...")
        UtilsC.expectation_A(q_Z, mu_M, sigma_M, PL, sigma_epsilone, Gamma,
                             Sigma_H, Y, y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, K)
        val = np.reshape(m_A, (M * J))
        val[np.where((val <= 1e-50) & (val > 0.0))] = 0.0
        val[np.where((val >= -1e-50) & (val < 0.0))] = 0.0

        # crit. A
        DIFF = np.reshape(m_A - m_A1, (M * J))
        # To avoid numerical problems
        DIFF[np.where((DIFF < 1e-50) & (DIFF > 0.0))] = 0.0
        # To avoid numerical problems
        DIFF[np.where((DIFF > -1e-50) & (DIFF < 0.0))] = 0.0
        Crit_A = (
            np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(m_A1, (M * J)))) ** 2
        cA += [Crit_A]
        m_A1[:, :] = m_A[:, :]

        # HRF h
        if estimateHRF:
            ################################
            #  HRF ESTIMATION
            ################################
            UtilsC.expectation_H(XGamma, Q_barnCond, sigma_epsilone, Gamma, R, Sigma_H, Y,
                                 y_tilde, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N, scale, sigmaH)

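            # Constrain with optimization strategy: project m_H onto the set of
            # vectors with zero endpoints and L2 norm at most 1, in the
            # Sigma_H^-1 metric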
            import cvxpy as cvx
            m, n = Sigma_H.shape
            Sigma_H_inv = np.linalg.inv(Sigma_H)
            zeros_H = np.zeros_like(m_H[:, np.newaxis])

            # Construct the problem. PRIMAL
            h = cvx.Variable(n)
            expression = cvx.quad_form(h - m_H[:, np.newaxis], Sigma_H_inv)
            objective = cvx.Minimize(expression)
            #constraints = [h[0] == 0, h[-1]==0, h >= zeros_H, cvx.square(cvx.norm(h,2))<=1]
            constraints = [
                h[0] == 0, h[-1] == 0, cvx.square(cvx.norm(h, 2)) <= 1]
            prob = cvx.Problem(objective, constraints)
            result = prob.solve(verbose=0, solver=cvx.CVXOPT)

            # Now we update the mean of h
            m_H_old = m_H
            Sigma_H_old = Sigma_H
            m_H = np.squeeze(np.array((h.value)))
            Sigma_H = np.zeros_like(Sigma_H)

            h_norm += [np.linalg.norm(m_H)]
            # print 'h_norm = ', h_norm

            # Plotting HRF
            if PLOT and ni >= 0:
                import matplotlib.pyplot as plt
                plt.figure(M + 1)
                plt.plot(m_H)
                plt.hold(True)
        else:
            if TrueHrfFlag:
                TrueVal, head = read_volume(HrfFilename)
                TrueVal = TrueVal[:, 0, 0, 0]
                print TrueVal
                print TrueVal.shape
                m_H = TrueVal

        # crit. h
        Crit_H = (np.linalg.norm(m_H - m_H1) / np.linalg.norm(m_H1)) ** 2
        cH += [Crit_H]
        m_H1[:] = m_H[:]

        # crit. AH
        for d in xrange(0, D):
            AH[:, :, d] = m_A[:, :] * m_H[d]
        DIFF = np.reshape(AH - AH1, (M * J * D))
        # To avoid numerical problems
        DIFF[np.where((DIFF < 1e-50) & (DIFF > 0.0))] = 0.0
        # To avoid numerical problems
        DIFF[np.where((DIFF > -1e-50) & (DIFF < 0.0))] = 0.0
        if np.linalg.norm(np.reshape(AH1, (M * J * D))) == 0:
            Crit_AH = 1000000000.
        else:
            Crit_AH = (
                np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(AH1, (M * J * D)))) ** 2
        cAH += [Crit_AH]
        AH1[:, :, :] = AH[:, :, :]

        # Z labels
        if estimateLabels:
            logger.info("E Z step ...")
            # WARNING!!! ParsiMod gives better results, but we need the other
            # one.
            if MFapprox:
                UtilsC.expectation_Z(Sigma_A, m_A, sigma_M, Beta, Z_tilde, mu_M, q_Z, neighboursIndexes.astype(
                    np.int32), M, J, K, maxNeighbours)
            if not MFapprox:
                UtilsC.expectation_Z_ParsiMod_RVM_and_CompMod(
                    Sigma_A, m_A, sigma_M, Beta, mu_M, q_Z, neighboursIndexes.astype(np.int32), M, J, K, maxNeighbours)
        else:
            logger.info("Using True Z ...")
            TrueZ = read_volume(LabelsFilename)
            for m in xrange(M):
                q_Z[m, 1, :] = np.reshape(TrueZ[0][:, :, :, m], J)
                q_Z[m, 0, :] = 1 - q_Z[m, 1, :]

        # crit. Z
        val = np.reshape(q_Z, (M * K * J))
        val[np.where((val <= 1e-50) & (val > 0.0))] = 0.0

        DIFF = np.reshape(q_Z - q_Z1, (M * K * J))
        # To avoid numerical problems
        DIFF[np.where((DIFF < 1e-50) & (DIFF > 0.0))] = 0.0
        # To avoid numerical problems
        DIFF[np.where((DIFF > -1e-50) & (DIFF < 0.0))] = 0.0
        if np.linalg.norm(np.reshape(q_Z1, (M * K * J))) == 0:
            Crit_Z = 1000000000.
        else:
            Crit_Z = (
                np.linalg.norm(DIFF) / np.linalg.norm(np.reshape(q_Z1, (M * K * J)))) ** 2
        cZ += [Crit_Z]
        q_Z1 = q_Z

        #####################
        # MAXIMIZATION
        #####################

        # HRF: Sigma_h
        if estimateHRF:
            if estimateSigmaH:
                logger.info("M sigma_H step ...")
                if gamma_h > 0:
                    sigmaH = vt.maximization_sigmaH_prior(
                        D, Sigma_H_old, R, m_H_old, gamma_h)
                else:
                    sigmaH = vt.maximization_sigmaH(D, Sigma_H, R, m_H)
                logger.info('sigmaH = %s', str(sigmaH))

        # (mu,sigma)
        logger.info("M (mu,sigma) step ...")
        mu_M, sigma_M = vt.maximization_mu_sigma(
            mu_M, sigma_M, q_Z, m_A, K, M, Sigma_A)
        for m in xrange(M):
            SUM_q_Z[m] += [sum(q_Z[m, 1, :])]
            mu1[m] += [mu_M[m, 1]]

        # Drift L
        UtilsC.maximization_L(
            Y, m_A, m_H, L, P, XX.astype(np.int32), J, D, M, Ndrift, N)
        PL = np.dot(P, L)
        y_tilde = Y - PL

        # Beta
        if estimateBeta:
            logger.info("estimating beta")
            for m in xrange(0, M):
                if MFapprox:
                    Beta[m] = UtilsC.maximization_beta(beta, q_Z[m, :, :].astype(np.float64), Z_tilde[m, :, :].astype(
                        np.float64), J, K, neighboursIndexes.astype(np.int32), gamma, maxNeighbours, MaxItGrad, gradientStep)
                if not MFapprox:
                    #Beta[m] = UtilsC.maximization_beta(beta,q_Z[m,:,:].astype(np.float64),q_Z[m,:,:].astype(np.float64),J,K,neighboursIndexes.astype(int32),gamma,maxNeighbours,MaxItGrad,gradientStep)
                    Beta[m] = UtilsC.maximization_beta_CB(beta, q_Z[m, :, :].astype(
                        np.float64), J, K, neighboursIndexes.astype(np.int32), gamma, maxNeighbours, MaxItGrad, gradientStep)
            logger.info("End estimating beta")
            logger.info(Beta)

        # Sigma noise
        logger.info("M sigma noise step ...")
        UtilsC.maximization_sigma_noise(
            Gamma, PL, sigma_epsilone, Sigma_H, Y, m_A, m_H, Sigma_A, XX.astype(np.int32), J, D, M, N)

        #### Computing Free Energy ####
        if ni > 0:
            FreeEnergy1 = FreeEnergy

        """FreeEnergy = vt.Compute_FreeEnergy(y_tilde,m_A,Sigma_A,mu_M,sigma_M,
                                           m_H,Sigma_H,R,Det_invR,sigmaH,
                                           p_Wtilde,q_Z,neighboursIndexes,
                                           maxNeighbours,Beta,sigma_epsilone,
                                           XX,Gamma,Det_Gamma,XGamma,J,D,M,
                                           N,K,S,"CompMod")"""
        FreeEnergy = vt.Compute_FreeEnergy(y_tilde, m_A, Sigma_A, mu_M, sigma_M, m_H, Sigma_H, R, Det_invR, sigmaH, p_Wtilde, tau1,
                                           tau2, q_Z, neighboursIndexes, maxNeighbours, Beta, sigma_epsilone, XX, Gamma, Det_Gamma, XGamma, J, D, M, N, K, S, "CompMod")

        if ni > 0:
            Crit_FreeEnergy = (FreeEnergy1 - FreeEnergy) / FreeEnergy1
        FreeEnergy_Iter += [FreeEnergy]
        cFE += [Crit_FreeEnergy]

        # Update index
        ni += 1

        t02 = time.time()
        cTime += [t02 - t1]

    t2 = time.time()

    ##########################################################################
    # PLOTS and SNR computation

    FreeEnergyArray = np.zeros((ni), dtype=np.float64)
    for i in xrange(ni):
        FreeEnergyArray[i] = FreeEnergy_Iter[i]

    SUM_q_Z_array = np.zeros((M, ni), dtype=np.float64)
    mu1_array = np.zeros((M, ni), dtype=np.float64)
    h_norm_array = np.zeros((ni), dtype=np.float64)
    for m in xrange(M):
        for i in xrange(ni):
            SUM_q_Z_array[m, i] = SUM_q_Z[m][i]
            mu1_array[m, i] = mu1[m][i]
            h_norm_array[i] = h_norm[i]

    if PLOT and 0:
        import matplotlib.pyplot as plt
        import matplotlib
        font = {'size': 15}
        matplotlib.rc('font', **font)
        plt.savefig('./HRF_Iter_CompMod.png')
        plt.hold(False)
        plt.figure(2)
        plt.plot(cAH[1:-1], 'lightblue')
        plt.hold(True)
        plt.plot(cFE[1:-1], 'm')
        plt.hold(False)
        #plt.legend( ('CA','CH', 'CZ', 'CAH', 'CFE') )
        plt.legend(('CAH', 'CFE'))
        plt.grid(True)
        plt.savefig('./Crit_CompMod.png')
        plt.figure(3)
        plt.plot(FreeEnergyArray)
        plt.grid(True)
        plt.savefig('./FreeEnergy_CompMod.png')

        plt.figure(4)
        for m in xrange(M):
            plt.plot(SUM_q_Z_array[m])
            plt.hold(True)
        plt.hold(False)
        #plt.legend( ('m=0','m=1', 'm=2', 'm=3') )
        #plt.legend( ('m=0','m=1') )
        plt.savefig('./Sum_q_Z_Iter_CompMod.png')

        plt.figure(5)
        for m in xrange(M):
            plt.plot(mu1_array[m])
            plt.hold(True)
        plt.hold(False)
        plt.savefig('./mu1_Iter_CompMod.png')

        plt.figure(6)
        plt.plot(h_norm_array)
        plt.savefig('./HRF_Norm_CompMod.png')

        Data_save = xndarray(h_norm_array, ['Iteration'])
        Data_save.save('./HRF_Norm_Comp.nii')

    CompTime = t2 - t1
    cTimeMean = CompTime / ni

    sigma_M = np.sqrt(np.sqrt(sigma_M))
    logger.info("Nb iterations to reach criterion: %d", ni)
    logger.info("Computational time = %s min %s s", str(
        int(CompTime // 60)), str(int(CompTime % 60)))
    # print "Computational time = " + str(int( CompTime//60 ) ) + " min " + str(int(CompTime%60)) + " s"
    # print "sigma_H = " + str(sigmaH)
    logger.info('mu_M: %s', mu_M)
    logger.info('sigma_M: %s', sigma_M)
    logger.info("sigma_H = %s", str(sigmaH))
    logger.info("Beta = %s", str(Beta))

    StimulusInducedSignal = vt.computeFit(m_H, m_A, X, J, N)
    SNR = 20 * \
        np.log(
            np.linalg.norm(Y) / np.linalg.norm(Y - StimulusInducedSignal - PL))
    SNR /= np.log(10.)
    logger.info("SNR = %d", SNR)
    return ni, m_A, m_H, q_Z, sigma_epsilone, mu_M, sigma_M, Beta, L, PL, CONTRAST, CONTRASTVAR, cA[2:], cH[2:], cZ[2:], cAH[2:], cTime[2:], cTimeMean, Sigma_A, StimulusInducedSignal, FreeEnergyArray