def test_initialise():
    """Verify initialise() under each scheme: 'random' yields non-negative draws,
    'exp' sets entries to the prior expectation 1/lambda, and 'kmeans' gives
    cluster-indicator values (0.2 or 1.2) for F and G."""
    I, J, K, L = 5, 3, 2, 4
    R = numpy.ones((I,J))
    M = numpy.ones((I,J))

    lambdaF = 2*numpy.ones((I,K))
    lambdaS = 3*numpy.ones((K,L))
    lambdaG = 4*numpy.ones((J,L))
    priors = { 'alpha':3, 'beta':1, 'lambdaF':lambdaF, 'lambdaS':lambdaS, 'lambdaG':lambdaG }

    # Fully random initialisation - we can only check non-negativity.
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise('random','random')

    assert model.tau >= 0.0
    for i in xrange(I):
        for k in xrange(K):
            assert model.F[i,k] >= 0.0
    for k in xrange(K):
        for l in xrange(L):
            assert model.S[k,l] >= 0.0
    for j in xrange(J):
        for l in xrange(L):
            assert model.G[j,l] >= 0.0

    # S random, F and G at the prior expectation.
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise('random','exp')

    for i,k in itertools.product(xrange(I),xrange(K)):
        assert model.F[i,k] == 1./lambdaF[i,k]
    for k,l in itertools.product(xrange(K),xrange(L)):
        assert model.S[k,l] != 1./lambdaS[k,l] # random draw should overwrite the expectation
    for j,l in itertools.product(xrange(J),xrange(L)):
        assert model.G[j,l] == 1./lambdaG[j,l]

    # S at the prior expectation, F and G random.
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise('exp','random')

    for i,k in itertools.product(xrange(I),xrange(K)):
        assert model.F[i,k] != 1./lambdaF[i,k] # random draw should overwrite the expectation
    for k,l in itertools.product(xrange(K),xrange(L)):
        assert model.S[k,l] == 1./lambdaS[k,l]
    for j,l in itertools.product(xrange(J),xrange(L)):
        assert model.G[j,l] != 1./lambdaG[j,l] # random draw should overwrite the expectation

    # F and G via K-means, S at the prior expectation.
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise('exp','kmeans')

    for i,k in itertools.product(xrange(I),xrange(K)):
        assert model.F[i,k] == 0.2 or model.F[i,k] == 1.2
    for j,l in itertools.product(xrange(J),xrange(L)):
        assert model.G[j,l] == 0.2 or model.G[j,l] == 1.2
    for k,l in itertools.product(xrange(K),xrange(L)):
        assert model.S[k,l] == 1./lambdaS[k,l]
# Example #2
def test_beta_s():
    """Check beta_s() against the hand-computed posterior rate parameter
    (uses the module-level R, M, priors and init fixtures)."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.
    # F*S = [[1/6+1/6=1/3,..]], F*S*G^T = [[1/15*4=4/15,..]]
    expected = beta + .5 * (12 * (11. / 15.)**2)
    assert abs(model.beta_s() - expected) < 0.00000000000001
def test_log_likelihood():
    """quality() should return the log likelihood / AIC / BIC / MSE of the
    posterior mean estimate, and raise AssertionError for an unknown metric."""
    R = numpy.array([[1,2],[3,4]],dtype=float)
    M = numpy.array([[1,1],[0,1]])
    I, J, K, L = 2, 2, 3, 4
    lambdaF = 2*numpy.ones((I,K))
    lambdaS = 3*numpy.ones((K,L))
    lambdaG = 4*numpy.ones((J,L))
    priors = { 'alpha':3, 'beta':1, 'lambdaF':lambdaF, 'lambdaS':lambdaS, 'lambdaG':lambdaG }

    iterations = 10
    burnin, thinning = 4, 2
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    # Constant draws so the posterior mean is exactly F=1, S=2, G=3, tau=3.
    model.all_F = [numpy.ones((I,K)) for _ in range(iterations)]
    model.all_S = [2*numpy.ones((K,L)) for _ in range(iterations)]
    model.all_G = [3*numpy.ones((J,L)) for _ in range(iterations)]
    model.all_tau = [3. for _ in range(iterations)]
    # expU*expV.T = [[72.]]

    log_likelihood = 3./2.*(math.log(3)-math.log(2*math.pi)) - 3./2. * (71**2 + 70**2 + 68**2)
    AIC = -2*log_likelihood + 2*(2*3+3*4+2*4)
    BIC = -2*log_likelihood + (2*3+3*4+2*4)*math.log(3)
    MSE = (71**2+70**2+68**2)/3.

    assert model.quality('loglikelihood',burnin,thinning) == log_likelihood
    assert model.quality('AIC',burnin,thinning) == AIC
    assert model.quality('BIC',burnin,thinning) == BIC
    assert model.quality('MSE',burnin,thinning) == MSE
    with pytest.raises(AssertionError) as error:
        model.quality('FAIL',burnin,thinning)
    assert str(error.value) == "Unrecognised metric for model quality: FAIL."
def test_run():
    """run() should store one draw per iteration, and the stored draws should
    differ from the 'exp' initialisation values (sampling actually updates)."""
    I,J,K,L = 10,5,3,2
    R = numpy.ones((I,J))
    M = numpy.ones((I,J))
    M[0,0], M[2,2], M[3,1] = 0, 0, 0

    lambdaF = 2*numpy.ones((I,K))
    lambdaS = 3*numpy.ones((K,L))
    lambdaG = 4*numpy.ones((J,L))
    alpha, beta = 3, 1
    priors = { 'alpha':alpha, 'beta':beta, 'lambdaF':lambdaF, 'lambdaS':lambdaS, 'lambdaG':lambdaG }

    # 'exp' initialisation gives F=1/2, S=1/3, G=1/4.
    F_start = numpy.ones((I,K))/2.
    S_start = numpy.ones((K,L))/3.
    G_start = numpy.ones((J,L))/4.

    iterations = 15

    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise('exp')
    Fs, Ss, Gs, taus = model.run(iterations)

    assert model.all_F.shape == (iterations,I,K)
    assert model.all_S.shape == (iterations,K,L)
    assert model.all_G.shape == (iterations,J,L)
    assert model.all_tau.shape == (iterations,)

    for i,k in itertools.product(xrange(I),xrange(K)):
        assert Fs[0,i,k] != F_start[i,k]
    for k,l in itertools.product(xrange(K),xrange(L)):
        assert Ss[0,k,l] != S_start[k,l]
    for j,l in itertools.product(xrange(J),xrange(L)):
        assert Gs[0,j,l] != G_start[j,l]
    assert taus[1] != alpha/float(beta)
def test_approx_expectation():
    """approx_expectation(burn_in,thinning) should average the thinned draws:
    burn_in=2, thinning=3 selects indices 2,5,8, i.e. m=3,6,9."""
    burn_in, thinning = 2, 3
    I, J, K, L = 5, 3, 2, 4
    # m-th draw is constant: 3m^2 for F, 2m^2 for S, m^2 for G and tau.
    draws_F = [numpy.ones((I,K)) * 3*m**2 for m in range(1,10+1)]
    draws_S = [numpy.ones((K,L)) * 2*m**2 for m in range(1,10+1)]
    draws_G = [numpy.ones((J,L)) * 1*m**2 for m in range(1,10+1)]
    draws_tau = [m**2 for m in range(1,10+1)]

    # Averages over m in {3,6,9}: sum of m^2 is 9+36+81.
    expected_exp_tau = (9.+36.+81.)/3.
    expected_exp_F = numpy.ones((I,K)) * (9.+36.+81.)
    expected_exp_S = numpy.ones((K,L)) * ((9.+36.+81.)*(2./3.))
    expected_exp_G = numpy.ones((J,L)) * ((9.+36.+81.)*(1./3.))

    R = numpy.ones((I,J))
    M = numpy.ones((I,J))
    lambdaF = 2*numpy.ones((I,K))
    lambdaS = 3*numpy.ones((K,L))
    lambdaG = 4*numpy.ones((J,L))
    priors = { 'alpha':3, 'beta':1, 'lambdaF':lambdaF, 'lambdaS':lambdaS, 'lambdaG':lambdaG }

    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.all_F = draws_F
    model.all_S = draws_S
    model.all_G = draws_G
    model.all_tau = draws_tau
    (exp_F, exp_S, exp_G, exp_tau) = model.approx_expectation(burn_in,thinning)

    assert exp_tau == expected_exp_tau
    assert numpy.array_equal(expected_exp_F,exp_F)
    assert numpy.array_equal(expected_exp_S,exp_S)
    assert numpy.array_equal(expected_exp_G,exp_G)
# Example #6
def test_compute_statistics():
    """Check compute_MSE, compute_R2 and compute_Rp on a hand-worked example
    where only the second row of predictions is inside the test mask."""
    R = numpy.array([[1, 2], [3, 4]], dtype=float)
    M = numpy.array([[1, 1], [0, 1]])
    I, J, K, L = 2, 2, 3, 4
    lambdaF = 2 * numpy.ones((I, K))
    lambdaS = 3 * numpy.ones((K, L))
    lambdaG = 4 * numpy.ones((J, L))
    priors = {'alpha': 3, 'beta': 1, 'lambdaF': lambdaF, 'lambdaS': lambdaS, 'lambdaG': lambdaG}

    model = bnmtf_gibbs_optimised(R, M, K, L, priors)

    R_pred = numpy.array([[500, 550], [1220, 1342]], dtype=float)
    M_pred = numpy.array([[0, 0], [1, 1]])

    # Observed entries under M_pred: R=[3,4] vs predictions [1220,1342] -> errors 1217, 1338.
    MSE_pred = (1217**2 + 1338**2) / 2.0
    R2_pred = 1. - (1217**2 + 1338**2) / (0.5**2 + 0.5**2)  #mean=3.5
    Rp_pred = 61. / (math.sqrt(.5) * math.sqrt(7442.))  #mean=3.5,var=0.5,mean_pred=1281,var_pred=7442,cov=61

    assert model.compute_MSE(M_pred, R, R_pred) == MSE_pred
    assert model.compute_R2(M_pred, R, R_pred) == R2_pred
    assert model.compute_Rp(M_pred, R, R_pred) == Rp_pred
def test_tauG():
    """tauG(l) should give the posterior precisions for column l of G
    (uses the module-level R, M, priors and init fixtures)."""
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise(init_S,init_FG)
    model.tau = 3.
    # F*S = [[1/3]], (F*S)^2 = [[1/9]], sum_i F*S = [[4/9]]
    expected = 3. * numpy.array([[4./9.] * 4] * 3)
    for j,l in itertools.product(xrange(J),xrange(L)):
        assert model.tauG(l)[j] == expected[j,l]
def test_tauS():
    """tauS(k,l) should give the posterior precision for S[k,l]
    (uses the module-level R, M, priors and init fixtures)."""
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise(init_S,init_FG)
    model.tau = 3.
    # F outer G = [[1/10]], (F outer G)^2 = [[1/100]], sum (F outer G)^2 = [[12/100]]
    expected = 3. * numpy.array([[3./25.] * 4] * 2)
    for k,l in itertools.product(xrange(K),xrange(L)):
        assert abs(model.tauS(k,l) - expected[k,l]) < 0.000000000000001
def test_tauF():
    """tauF(k) should give the posterior precisions for column k of F
    (uses the module-level R, M, priors and init fixtures)."""
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise(init_S,init_FG)
    model.tau = 3.
    # S*G.T = [[4/15]], (S*G.T)^2 = [[16/225]]; per-row sums alternate 32/225 and 48/225.
    low = [32./225., 32./225.]
    high = [48./225., 48./225.]
    expected = 3. * numpy.array([low, high, low, low, high])
    for i,k in itertools.product(xrange(I),xrange(K)):
        assert abs(model.tauF(k)[i] - expected[i,k]) < 0.000000000000001
# Example #10
def test_tauS():
    """tauS(k,l): posterior precision for S[k,l], i.e. tau times the sum of
    squared products F_ik*G_jl over the observed entries."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.
    # F outer G = [[1/10]], (F outer G)^2 = [[1/100]], sum (F outer G)^2 = [[12/100]]
    expected = 3. * numpy.array([[3. / 25.] * 4, [3. / 25.] * 4])
    for k in xrange(K):
        for l in xrange(L):
            assert abs(model.tauS(k, l) - expected[k, l]) < 0.000000000000001
# Example #11
def test_predict():
    """predict() should average the thinned draws (burn_in=2, thinning=3 ->
    indices 2,5,8, i.e. m=3,6,9), build R_pred, and report MSE / R^2 / Rp
    on the entries selected by M_test."""
    burn_in, thinning = 2, 3
    I, J, K, L = 5, 3, 2, 4
    # m-th draw is constant: 3m^2 for F, 2m^2 for S, m^2 for G and tau.
    draws_F = [numpy.ones((I, K)) * 3 * m**2 for m in range(1, 10 + 1)]
    draws_S = [numpy.ones((K, L)) * 2 * m**2 for m in range(1, 10 + 1)]
    draws_G = [numpy.ones((J, L)) * 1 * m**2 for m in range(1, 10 + 1)]
    draws_F[2][0, 0] = 24  #instead of 27 - to ensure we do not get 0 variance in our predictions
    draws_tau = [m**2 for m in range(1, 10 + 1)]

    R = numpy.array(
        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],
        dtype=float)
    M = numpy.ones((I, J))
    lambdaF = 2 * numpy.ones((I, K))
    lambdaS = 3 * numpy.ones((K, L))
    lambdaG = 5 * numpy.ones((J, L))
    priors = {'alpha': 3, 'beta': 1, 'lambdaF': lambdaF, 'lambdaS': lambdaS, 'lambdaG': lambdaG}

    # Posterior means: exp_F = [[125,126],[126,126],...], exp_S all 84, exp_G all 42,
    # so R_pred is 3542112 on row 0 and 3556224 elsewhere.
    M_test = numpy.array(
        [[0, 0, 1], [0, 1, 0], [0, 0, 0], [1, 1, 0],
         [0, 0, 0]])  #R->3,5,10,11, R_pred->3542112,3556224,3556224,3556224
    MSE = ((3. - 3542112.)**2 + (5. - 3556224.)**2 + (10. - 3556224.)**2 +
           (11. - 3556224.)**2) / 4.
    R2 = 1. - ((3. - 3542112.)**2 + (5. - 3556224.)**2 + (10. - 3556224.)**2 +
               (11. - 3556224.)**2) / (4.25**2 + 2.25**2 + 2.75**2 + 3.75**2)  #mean=7.25
    Rp = 357. / (
        math.sqrt(44.75) * math.sqrt(5292.)
    )  #mean=7.25,var=44.75, mean_pred=3552696,var_pred=5292, corr=(-4.25*-63 + -2.25*21 + 2.75*21 + 3.75*21)

    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.all_F = draws_F
    model.all_S = draws_S
    model.all_G = draws_G
    model.all_tau = draws_tau
    performances = model.predict(M_test, burn_in, thinning)

    assert performances['MSE'] == MSE
    assert performances['R^2'] == R2
    assert performances['Rp'] == Rp
# Example #12
def test_tauG():
    """tauG(l): posterior precisions for column l of G, i.e. tau times the
    column sums of (F*S)^2 over the observed entries."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.
    # F*S = [[1/3]], (F*S)^2 = [[1/9]], sum_i F*S = [[4/9]]
    expected = 3. * numpy.array([[4. / 9.] * 4 for _ in range(3)])
    for j in xrange(J):
        for l in xrange(L):
            assert model.tauG(l)[j] == expected[j, l]
def test_muG():
    """muG(tauG_column,l) should give the posterior means for column l of G
    (uses the module-level R, M, priors, lambdaG and init fixtures)."""
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise(init_S,init_FG)
    model.tau = 3.
    tauG = 3. * numpy.array([[4./9.] * 4] * 3)
    # Rij - Fi*S*Gj + Gjl*(Fi*Sl)) = 11/15 + 1/5 * 1/3 = 12/15 = 4/5
    # (Rij - Fi*S*Gj + Gjl*(Fi*Sl)) * (Fi*Sl) = 4/5 * 1/3 = 4/15
    expected = 1./tauG * ( 3. * numpy.array([[4.*4./15.] * 4] * 3) - lambdaG )
    for j,l in itertools.product(xrange(J),xrange(L)):
        assert abs(model.muG(tauG[:,l],l)[j] - expected[j,l]) < 0.000000000000001
def test_muS():
    """muS(tauS_kl,k,l) should give the posterior mean for S[k,l]
    (uses the module-level R, M, priors, lambdaS and init fixtures)."""
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise(init_S,init_FG)
    model.tau = 3.
    tauS = 3. * numpy.array([[3./25.] * 4] * 2)
    # Rij - Fi*S*Gj + Fik*Skl*Gjk = 11/15 + 1/2*1/3*1/5 = 23/30
    # (Rij - Fi*S*Gj + Fik*Skl*Gjk) * Fik*Gjk = 23/30 * 1/10 = 23/300
    expected = 1./tauS * ( 3. * numpy.array([[12*23./300.] * 4] * 2) - lambdaS )
    for k,l in itertools.product(xrange(K),xrange(L)):
        assert abs(model.muS(tauS[k,l],k,l) - expected[k,l]) < 0.000000000000001
def test_muF():
    """muF(tauF_column,k) should give the posterior means for column k of F
    (uses the module-level R, M, priors, lambdaF and init fixtures)."""
    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.initialise(init_S,init_FG)
    model.tau = 3.
    low, high = [32./225.] * 2, [48./225.] * 2
    tauF = 3. * numpy.array([low, high, low, low, high])
    # Rij - Fi*S*Gj + Fik(Sk*Gj) = 11/15 + 1/2 * 4/15 = 13/15
    # (Rij - Fi*S*Gj + Fik(Sk*Gj)) * (Sk*Gj) = 13/15 * 4/15 = 52/225
    two, three = [2*(52./225.)] * 2, [3*(52./225.)] * 2
    expected = 1./tauF * ( 3. * numpy.array([two, three, two, two, three]) - lambdaF )
    for i,k in itertools.product(xrange(I),xrange(K)):
        assert abs(model.muF(tauF[:,k],k)[i] - expected[i,k]) < 0.000000000000001
# Example #16
def test_tauF():
    """tauF(k): posterior precisions for column k of F, i.e. tau times the
    row sums of (S*G^T)^2 over the observed entries."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.
    # S*G.T = [[4/15]], (S*G.T)^2 = [[16/225]]; row sums alternate 32/225 and 48/225.
    rows = [[32. / 225.] * 2, [48. / 225.] * 2,
            [32. / 225.] * 2, [32. / 225.] * 2, [48. / 225.] * 2]
    expected = 3. * numpy.array(rows)
    for i in xrange(I):
        for k in xrange(K):
            assert abs(model.tauF(k)[i] - expected[i, k]) < 0.000000000000001
# Example #17
def test_muS():
    """muS: posterior mean of S[k,l] given its precision tauS[k,l]."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.
    tauS = 3. * numpy.array([[3. / 25.] * 4, [3. / 25.] * 4])
    # Rij - Fi*S*Gj + Fik*Skl*Gjk = 11/15 + 1/2*1/3*1/5 = 23/30
    # (Rij - Fi*S*Gj + Fik*Skl*Gjk) * Fik*Gjk = 23/30 * 1/10 = 23/300
    expected = 1. / tauS * (3. * numpy.array(
        [[12 * 23. / 300.] * 4, [12 * 23. / 300.] * 4]) - lambdaS)
    for k in xrange(K):
        for l in xrange(L):
            assert abs(model.muS(tauS[k, l], k, l) - expected[k, l]) < 0.000000000000001
# Example #18
def test_muG():
    """muG: posterior means of column l of G given its precisions."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.
    tauG = 3. * numpy.array([[4. / 9.] * 4 for _ in range(3)])
    # Rij - Fi*S*Gj + Gjl*(Fi*Sl)) = 11/15 + 1/5 * 1/3 = 12/15 = 4/5
    # (Rij - Fi*S*Gj + Gjl*(Fi*Sl)) * (Fi*Sl) = 4/5 * 1/3 = 4/15
    expected = 1. / tauG * (3. * numpy.array(
        [[4. * 4. / 15.] * 4 for _ in range(3)]) - lambdaG)
    for j in xrange(J):
        for l in xrange(L):
            assert abs(model.muG(tauG[:, l], l)[j] - expected[j, l]) < 0.000000000000001
# Example #19
def test_muF():
    """muF: posterior means of column k of F given its precisions."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.
    tau_low, tau_high = [32. / 225.] * 2, [48. / 225.] * 2
    tauF = 3. * numpy.array([tau_low, tau_high, tau_low, tau_low, tau_high])
    # Rij - Fi*S*Gj + Fik(Sk*Gj) = 11/15 + 1/2 * 4/15 = 13/15
    # (Rij - Fi*S*Gj + Fik(Sk*Gj)) * (Sk*Gj) = 13/15 * 4/15 = 52/225
    mu_low, mu_high = [2 * (52. / 225.)] * 2, [3 * (52. / 225.)] * 2
    expected = 1. / tauF * (3. * numpy.array(
        [mu_low, mu_high, mu_low, mu_low, mu_high]) - lambdaF)
    for i in xrange(I):
        for k in xrange(K):
            assert abs(model.muF(tauF[:, k], k)[i] - expected[i, k]) < 0.000000000000001
# Example #20
def test_run():
    """After run(), one draw per iteration is stored and every first-iteration
    draw differs from the 'exp' starting values."""
    I, J, K, L = 10, 5, 3, 2
    R = numpy.ones((I, J))
    M = numpy.ones((I, J))
    M[0, 0], M[2, 2], M[3, 1] = 0, 0, 0

    lambdaF = 2 * numpy.ones((I, K))
    lambdaS = 3 * numpy.ones((K, L))
    lambdaG = 4 * numpy.ones((J, L))
    alpha, beta = 3, 1
    priors = {
        'alpha': alpha,
        'beta': beta,
        'lambdaF': lambdaF,
        'lambdaS': lambdaS,
        'lambdaG': lambdaG
    }

    # 'exp' initialisation gives F=1/2, S=1/3, G=1/4.
    F_start = numpy.ones((I, K)) / 2.
    S_start = numpy.ones((K, L)) / 3.
    G_start = numpy.ones((J, L)) / 4.

    iterations = 15

    sampler = bnmtf_gibbs_optimised(R, M, K, L, priors)
    sampler.initialise('exp')
    Fs, Ss, Gs, taus = sampler.run(iterations)

    assert sampler.all_F.shape == (iterations, I, K)
    assert sampler.all_S.shape == (iterations, K, L)
    assert sampler.all_G.shape == (iterations, J, L)
    assert sampler.all_tau.shape == (iterations,)

    for i in xrange(I):
        for k in xrange(K):
            assert Fs[0, i, k] != F_start[i, k]
    for k in xrange(K):
        for l in xrange(L):
            assert Ss[0, k, l] != S_start[k, l]
    for j in xrange(J):
        for l in xrange(L):
            assert Gs[0, j, l] != G_start[j, l]
    assert taus[1] != alpha / float(beta)
def test_compute_statistics():
    """compute_MSE / compute_R2 / compute_Rp on the masked bottom row of predictions."""
    R = numpy.array([[1,2],[3,4]],dtype=float)
    M = numpy.array([[1,1],[0,1]])
    I, J, K, L = 2, 2, 3, 4
    lambdaF = 2*numpy.ones((I,K))
    lambdaS = 3*numpy.ones((K,L))
    lambdaG = 4*numpy.ones((J,L))
    priors = { 'alpha':3, 'beta':1, 'lambdaF':lambdaF, 'lambdaS':lambdaS, 'lambdaG':lambdaG }

    model = bnmtf_gibbs_optimised(R,M,K,L,priors)

    R_pred = numpy.array([[500,550],[1220,1342]],dtype=float)
    M_pred = numpy.array([[0,0],[1,1]])

    # Errors on the two entries in M_pred: 1220-3=1217 and 1342-4=1338.
    MSE_pred = (1217**2 + 1338**2) / 2.0
    R2_pred = 1. - (1217**2+1338**2)/(0.5**2+0.5**2) #mean=3.5
    Rp_pred = 61. / ( math.sqrt(.5) * math.sqrt(7442.) ) #mean=3.5,var=0.5,mean_pred=1281,var_pred=7442,cov=61

    assert model.compute_MSE(M_pred,R,R_pred) == MSE_pred
    assert model.compute_R2(M_pred,R,R_pred) == R2_pred
    assert model.compute_Rp(M_pred,R,R_pred) == Rp_pred
def test_predict():
    """predict() averages the thinned draws (indices 2,5,8 -> m=3,6,9), forms
    R_pred from the posterior means, and reports MSE / R^2 / Rp on M_test."""
    burn_in, thinning = 2, 3
    I, J, K, L = 5, 3, 2, 4
    draws_F = [numpy.ones((I,K)) * 3*m**2 for m in range(1,10+1)]
    draws_S = [numpy.ones((K,L)) * 2*m**2 for m in range(1,10+1)]
    draws_G = [numpy.ones((J,L)) * 1*m**2 for m in range(1,10+1)] # m-th draw is constant (scaled m^2)
    draws_F[2][0,0] = 24 #instead of 27 - to ensure we do not get 0 variance in our predictions
    draws_tau = [m**2 for m in range(1,10+1)]

    R = numpy.array([[1,2,3],[4,5,6],[7,8,9],[10,11,12],[13,14,15]],dtype=float)
    M = numpy.ones((I,J))
    lambdaF = 2*numpy.ones((I,K))
    lambdaS = 3*numpy.ones((K,L))
    lambdaG = 5*numpy.ones((J,L))
    priors = { 'alpha':3, 'beta':1, 'lambdaF':lambdaF, 'lambdaS':lambdaS, 'lambdaG':lambdaG }

    # Posterior means: exp_F = [[125,126],[126,126],...], exp_S all 84, exp_G all 42,
    # so R_pred is 3542112 on row 0 and 3556224 elsewhere.
    M_test = numpy.array([[0,0,1],[0,1,0],[0,0,0],[1,1,0],[0,0,0]]) #R->3,5,10,11, R_pred->3542112,3556224,3556224,3556224
    MSE = ((3.-3542112.)**2 + (5.-3556224.)**2 + (10.-3556224.)**2 + (11.-3556224.)**2) / 4.
    R2 = 1. - ((3.-3542112.)**2 + (5.-3556224.)**2 + (10.-3556224.)**2 + (11.-3556224.)**2) / (4.25**2+2.25**2+2.75**2+3.75**2) #mean=7.25
    Rp = 357. / ( math.sqrt(44.75) * math.sqrt(5292.) ) #mean=7.25,var=44.75, mean_pred=3552696,var_pred=5292, corr=(-4.25*-63 + -2.25*21 + 2.75*21 + 3.75*21)

    model = bnmtf_gibbs_optimised(R,M,K,L,priors)
    model.all_F = draws_F
    model.all_S = draws_S
    model.all_G = draws_G
    model.all_tau = draws_tau
    performances = model.predict(M_test,burn_in,thinning)

    assert performances['MSE'] == MSE
    assert performances['R^2'] == R2
    assert performances['Rp'] == Rp
# Example #23
def test_log_likelihood():
    """quality() returns log likelihood / AIC / BIC / MSE of the posterior mean
    estimate; an unknown metric name must raise AssertionError."""
    R = numpy.array([[1, 2], [3, 4]], dtype=float)
    M = numpy.array([[1, 1], [0, 1]])
    I, J, K, L = 2, 2, 3, 4
    lambdaF = 2 * numpy.ones((I, K))
    lambdaS = 3 * numpy.ones((K, L))
    lambdaG = 4 * numpy.ones((J, L))
    priors = {
        'alpha': 3,
        'beta': 1,
        'lambdaF': lambdaF,
        'lambdaS': lambdaS,
        'lambdaG': lambdaG
    }

    iterations = 10
    burnin, thinning = 4, 2
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    # Constant draws so the posterior mean is exactly F=1, S=2, G=3, tau=3.
    model.all_F = [numpy.ones((I, K)) for _ in range(iterations)]
    model.all_S = [2 * numpy.ones((K, L)) for _ in range(iterations)]
    model.all_G = [3 * numpy.ones((J, L)) for _ in range(iterations)]
    model.all_tau = [3. for _ in range(iterations)]
    # expU*expV.T = [[72.]]

    log_likelihood = 3. / 2. * (math.log(3) - math.log(
        2 * math.pi)) - 3. / 2. * (71**2 + 70**2 + 68**2)
    AIC = -2 * log_likelihood + 2 * (2 * 3 + 3 * 4 + 2 * 4)
    BIC = -2 * log_likelihood + (2 * 3 + 3 * 4 + 2 * 4) * math.log(3)
    MSE = (71**2 + 70**2 + 68**2) / 3.

    assert model.quality('loglikelihood', burnin, thinning) == log_likelihood
    assert model.quality('AIC', burnin, thinning) == AIC
    assert model.quality('BIC', burnin, thinning) == BIC
    assert model.quality('MSE', burnin, thinning) == MSE
    with pytest.raises(AssertionError) as error:
        model.quality('FAIL', burnin, thinning)
    assert str(error.value) == "Unrecognised metric for model quality: FAIL."
# Example #24
def test_initialise():
    """initialise() behaviour per scheme: 'random' draws are non-negative,
    'exp' sets values to the prior expectation 1/lambda, 'kmeans' gives
    cluster-indicator values (0.2 or 1.2) for F and G."""
    I, J, K, L = 5, 3, 2, 4
    R = numpy.ones((I, J))
    M = numpy.ones((I, J))

    lambdaF = 2 * numpy.ones((I, K))
    lambdaS = 3 * numpy.ones((K, L))
    lambdaG = 4 * numpy.ones((J, L))
    priors = {
        'alpha': 3,
        'beta': 1,
        'lambdaF': lambdaF,
        'lambdaS': lambdaS,
        'lambdaG': lambdaG
    }

    # Fully random initialisation - only non-negativity can be verified.
    sampler = bnmtf_gibbs_optimised(R, M, K, L, priors)
    sampler.initialise('random', 'random')

    assert sampler.tau >= 0.0
    for i, k in itertools.product(xrange(I), xrange(K)):
        assert sampler.F[i, k] >= 0.0
    for k, l in itertools.product(xrange(K), xrange(L)):
        assert sampler.S[k, l] >= 0.0
    for j, l in itertools.product(xrange(J), xrange(L)):
        assert sampler.G[j, l] >= 0.0

    # S random, F and G at the prior expectation.
    sampler = bnmtf_gibbs_optimised(R, M, K, L, priors)
    sampler.initialise('random', 'exp')

    for i, k in itertools.product(xrange(I), xrange(K)):
        assert sampler.F[i, k] == 1. / lambdaF[i, k]
    for k, l in itertools.product(xrange(K), xrange(L)):
        # random draw should overwrite the expectation
        assert sampler.S[k, l] != 1. / lambdaS[k, l]
    for j, l in itertools.product(xrange(J), xrange(L)):
        assert sampler.G[j, l] == 1. / lambdaG[j, l]

    # S at the prior expectation, F and G random.
    sampler = bnmtf_gibbs_optimised(R, M, K, L, priors)
    sampler.initialise('exp', 'random')

    for i, k in itertools.product(xrange(I), xrange(K)):
        # random draw should overwrite the expectation
        assert sampler.F[i, k] != 1. / lambdaF[i, k]
    for k, l in itertools.product(xrange(K), xrange(L)):
        assert sampler.S[k, l] == 1. / lambdaS[k, l]
    for j, l in itertools.product(xrange(J), xrange(L)):
        # random draw should overwrite the expectation
        assert sampler.G[j, l] != 1. / lambdaG[j, l]

    # F and G via K-means, S at the prior expectation.
    sampler = bnmtf_gibbs_optimised(R, M, K, L, priors)
    sampler.initialise('exp', 'kmeans')

    for i, k in itertools.product(xrange(I), xrange(K)):
        assert sampler.F[i, k] == 0.2 or sampler.F[i, k] == 1.2
    for j, l in itertools.product(xrange(J), xrange(L)):
        assert sampler.G[j, l] == 0.2 or sampler.G[j, l] == 1.2
    for k, l in itertools.product(xrange(K), xrange(L)):
        assert sampler.S[k, l] == 1. / lambdaS[k, l]
# Example #25
def test_approx_expectation():
    """approx_expectation(burn_in,thinning) averages the thinned draws:
    burn_in=2, thinning=3 selects indices 2,5,8, i.e. m=3,6,9."""
    burn_in, thinning = 2, 3
    I, J, K, L = 5, 3, 2, 4
    # m-th draw is constant: 3m^2 for F, 2m^2 for S, m^2 for G and tau.
    draws_F = [numpy.ones((I, K)) * 3 * m**2 for m in range(1, 10 + 1)]
    draws_S = [numpy.ones((K, L)) * 2 * m**2 for m in range(1, 10 + 1)]
    draws_G = [numpy.ones((J, L)) * 1 * m**2 for m in range(1, 10 + 1)]
    draws_tau = [m**2 for m in range(1, 10 + 1)]

    # Averages over m in {3,6,9}: sum of m^2 is 9+36+81.
    expected_exp_tau = (9. + 36. + 81.) / 3.
    expected_exp_F = numpy.ones((I, K)) * (9. + 36. + 81.)
    expected_exp_S = numpy.ones((K, L)) * ((9. + 36. + 81.) * (2. / 3.))
    expected_exp_G = numpy.ones((J, L)) * ((9. + 36. + 81.) * (1. / 3.))

    R = numpy.ones((I, J))
    M = numpy.ones((I, J))
    lambdaF = 2 * numpy.ones((I, K))
    lambdaS = 3 * numpy.ones((K, L))
    lambdaG = 4 * numpy.ones((J, L))
    priors = {
        'alpha': 3,
        'beta': 1,
        'lambdaF': lambdaF,
        'lambdaS': lambdaS,
        'lambdaG': lambdaG
    }

    sampler = bnmtf_gibbs_optimised(R, M, K, L, priors)
    sampler.all_F = draws_F
    sampler.all_S = draws_S
    sampler.all_G = draws_G
    sampler.all_tau = draws_tau
    exp_F, exp_S, exp_G, exp_tau = sampler.approx_expectation(burn_in, thinning)

    assert exp_tau == expected_exp_tau
    assert numpy.array_equal(expected_exp_F, exp_F)
    assert numpy.array_equal(expected_exp_S, exp_S)
    assert numpy.array_equal(expected_exp_G, exp_G)
def test_init():
    """Constructor validation for bnmtf_gibbs_optimised.

    Checks, in order: rejection of non-2D R, of R/M size mismatches, of
    wrongly shaped prior matrices, of fully unobserved rows/columns in M,
    and finally that a valid construction stores every argument as an
    attribute (with scalar lambda priors expanded to full matrices).
    """
    alpha, beta = 3, 1

    def make_priors(lF, lS, lG):
        # Assemble the priors dictionary expected by the constructor.
        return {'alpha': alpha, 'beta': beta, 'lambdaF': lF, 'lambdaS': lS, 'lambdaG': lG}

    # --- R must be a two-dimensional array matching M's shape ---
    I, J, K, L = 5, 3, 1, 2
    M = numpy.ones((2, 3))
    lambdaF = numpy.ones((I, K))
    lambdaS = numpy.ones((K, L))
    lambdaG = numpy.ones((J, L))
    priors = make_priors(lambdaF, lambdaS, lambdaG)

    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(numpy.ones(3), M, K, L, priors)
    assert str(error.value) == "Input matrix R is not a two-dimensional array, but instead 1-dimensional."

    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(numpy.ones((4, 3, 2)), M, K, L, priors)
    assert str(error.value) == "Input matrix R is not a two-dimensional array, but instead 3-dimensional."

    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(numpy.ones((3, 2)), M, K, L, priors)
    assert str(error.value) == "Input matrix R is not of the same size as the indicator matrix M: (3, 2) and (2, 3) respectively."

    # --- each prior matrix must have the right shape ---
    I, J, K, L = 2, 3, 1, 2
    R4 = numpy.ones((2, 3))

    priors = make_priors(numpy.ones((2 + 1, 1)), lambdaS, lambdaG)
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, M, K, L, priors)
    assert str(error.value) == "Prior matrix lambdaF has the wrong shape: (3, 1) instead of (2, 1)."

    lambdaF = numpy.ones((2, 1))
    priors = make_priors(lambdaF, numpy.ones((1 + 1, 2 + 1)), lambdaG)
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, M, K, L, priors)
    assert str(error.value) == "Prior matrix lambdaS has the wrong shape: (2, 3) instead of (1, 2)."

    lambdaS = numpy.ones((1, 2))
    priors = make_priors(lambdaF, lambdaS, numpy.ones((3, 2 + 1)))
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, M, K, L, priors)
    assert str(error.value) == "Prior matrix lambdaG has the wrong shape: (3, 3) instead of (3, 2)."

    # --- every row and column of R must contain at least one observation ---
    lambdaF = numpy.ones((I, K))
    lambdaS = numpy.ones((K, L))
    lambdaG = numpy.ones((J, L))
    priors = make_priors(lambdaF, lambdaS, lambdaG)

    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, [[1, 1, 1], [0, 0, 0]], K, L, priors)
    assert str(error.value) == "Fully unobserved row in R, row 1."
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, [[1, 1, 0], [1, 0, 0]], K, L, priors)
    assert str(error.value) == "Fully unobserved column in R, column 2."

    # --- a valid construction stores all arguments as attributes ---
    I, J, K, L = 3, 2, 2, 2
    R5 = 2 * numpy.ones((I, J))
    M = numpy.ones((I, J))
    lambdaF = numpy.ones((I, K))
    lambdaS = numpy.ones((K, L))
    lambdaG = numpy.ones((J, L))
    BNMTF = bnmtf_gibbs_optimised(R5, M, K, L, make_priors(lambdaF, lambdaS, lambdaG))

    assert numpy.array_equal(BNMTF.R, R5)
    assert numpy.array_equal(BNMTF.M, M)
    assert BNMTF.I == I
    assert BNMTF.J == J
    assert BNMTF.K == K
    assert BNMTF.L == L
    assert BNMTF.size_Omega == I * J
    assert BNMTF.alpha == alpha
    assert BNMTF.beta == beta
    assert numpy.array_equal(BNMTF.lambdaF, lambdaF)
    assert numpy.array_equal(BNMTF.lambdaS, lambdaS)
    assert numpy.array_equal(BNMTF.lambdaG, lambdaG)

    # --- scalar lambda priors are expanded into constant matrices ---
    BNMTF = bnmtf_gibbs_optimised(R5, M, K, L, make_priors(3, 4, 5))

    assert numpy.array_equal(BNMTF.R, R5)
    assert numpy.array_equal(BNMTF.M, M)
    assert BNMTF.I == I
    assert BNMTF.J == J
    assert BNMTF.K == K
    assert BNMTF.L == L
    assert BNMTF.size_Omega == I * J
    assert BNMTF.alpha == alpha
    assert BNMTF.beta == beta
    assert numpy.array_equal(BNMTF.lambdaF, 3 * numpy.ones((I, K)))
    assert numpy.array_equal(BNMTF.lambdaS, 4 * numpy.ones((K, L)))
    assert numpy.array_equal(BNMTF.lambdaG, 5 * numpy.ones((J, L)))
Пример #27
0
# Reproducibility check for the Gibbs sampler: run it <repeats> times with
# identical seeds and verify the per-iteration performances are identical.
# NOTE(review): relies on module-level names defined elsewhere in this script
# (input_folder, I, J, K, L, repeats, priors, init_S, init_FG, iterations,
# calc_inverse_M) -- confirm they are in scope before this point.
R = numpy.loadtxt(input_folder + "R.txt")
M = numpy.ones((I, J))  # treat every entry of R as observed for training
#M = numpy.loadtxt(input_folder+"M.txt")
M_test = calc_inverse_M(numpy.loadtxt(input_folder + "M.txt"))

# Run the VB algorithm, <repeats> times
times_repeats = []         # one list of timestamps per repeat
performances_repeats = []  # one performance record per repeat
for i in range(0, repeats):
    # Set all the seeds
    # Seeding numpy, random and scipy before each run makes every repeat
    # draw the same random numbers, so results should be bitwise identical.
    numpy.random.seed(3)
    random.seed(4)
    scipy.random.seed(5)

    # Run the classifier
    BNMTF = bnmtf_gibbs_optimised(R, M, K, L, priors)
    BNMTF.initialise(init_S, init_FG)
    BNMTF.run(iterations)

    # Extract the performances and timestamps across all iterations
    times_repeats.append(BNMTF.all_times)
    performances_repeats.append(BNMTF.all_performances)

# Check whether seed worked: all performances should be the same
assert all([numpy.array_equal(performances, performances_repeats[0]) for performances in performances_repeats]), \
    "Seed went wrong - performances not the same across repeats!"

# Print out the performances, and the average times
# Average the timestamps element-wise across repeats (same seed => same
# iteration count per repeat, so the rows line up).
gibbs_all_times_average = list(numpy.average(times_repeats, axis=0))
gibbs_all_performances = performances_repeats[0]
print "gibbs_all_times_average = %s" % gibbs_all_times_average
Пример #28
0
# Run the Gibbs sampler on the Sanger drug-sensitivity data and report the
# training performance across iterations.
# NOTE(review): relies on module-level names defined elsewhere in this script
# (I, J, K, L, alpha, beta, standardised, init_S, init_FG, iterations,
# burnin, thinning, load_Sanger, plt) -- confirm they are in scope here.
lambdaF = numpy.ones((I, K)) / 10.  # flat lambda priors, rate 1/10
lambdaS = numpy.ones((K, L)) / 10.
lambdaG = numpy.ones((J, L)) / 10.
priors = {
    'alpha': alpha,
    'beta': beta,
    'lambdaF': lambdaF,
    'lambdaS': lambdaS,
    'lambdaG': lambdaG
}

# Load in data
# X_min is the data matrix, M its observation-indicator matrix.
(_, X_min, M, _, _, _, _) = load_Sanger(standardised=standardised)

# Run the Gibbs sampler
BNMTF = bnmtf_gibbs_optimised(X_min, M, K, L, priors)
BNMTF.initialise(init_S=init_S, init_FG=init_FG)
BNMTF.run(iterations)

# Also measure the performances on the training data
# (predicting the observed entries M themselves, after burn-in/thinning)
performances = BNMTF.predict(M, burnin, thinning)
print performances

# Plot the tau expectation values to check convergence
plt.plot(BNMTF.all_tau)

# Print the performances across iterations (MSE)
print "all_performances = %s" % BNMTF.all_performances['MSE']
'''
all_performances = [21.421197768547213, 6.3583912411769443, 4.1905150872867791, 3.5061419879402549, 3.2401904945288589, 3.1061034167609085, 3.0317883467556381, 2.9898304967466904, 2.9597551532537403, 2.9383503871636449, 2.9254205069284884, 2.9174938663203092, 2.9099499161449942, 2.9062570985645522, 2.9031497813762694, 2.8996412733085539, 2.8933585834560027, 2.8895195553715656, 2.8857902380755402, 2.8826893976014194, 2.8818869316601061, 2.879167497739882, 2.8810542681772024, 2.8774833883780992, 2.8764204742423538, 2.8690108092204154, 2.8725197104371101, 2.8674899173730819, 2.8647480905877538, 2.8626898729711101, 2.8593806833372533, 2.8565979156668497, 2.8534557000365472, 2.8486892736477119, 2.8442410993582907, 2.8424692916558705, 2.8381512986523192, 2.8355096495142504, 2.8377105511241458, 2.8331943900354211, 2.8304945163441682, 2.8222371793873791, 2.8187736926904341, 2.8142125953486694, 2.8123009163914499, 2.8141234880524757, 2.8117717085995619, 2.8059703725159513, 2.8066421988610086, 2.8055944223940736, 2.7997620244475629, 2.7982321373971564, 2.7983070728647377, 2.7970005782459055, 2.7979642884735307, 2.7950424452772769, 2.7929941875296747, 2.7881670598714097, 2.7871400065172396, 2.7816411365618094, 2.7818833625205426, 2.7744479857150077, 2.7730558648870067, 2.7711438261741677, 2.7718152242946443, 2.766961470633531, 2.7654244826300469, 2.7648775450099472, 2.759830401256282, 2.7604965970646158, 2.7565190544527352, 2.751721901129716, 2.7499851477712034, 2.7486851571379924, 2.7485557659818927, 2.7456710521803371, 2.7432921691336118, 2.7401237642034371, 2.7331482067432695, 2.7312389506051078, 2.7296576766624039, 2.7291115245357629, 2.7293278854136069, 2.7256303034532081, 2.7189792815582186, 2.7242246054550825, 2.7207063083185834, 2.7200853566998937, 2.7126052328111712, 2.7104206931618524, 2.7081642594411091, 2.7073972104837329, 2.7086208242493455, 2.7063489671878553, 2.7045511903566073, 2.700216589322828, 2.6985049679557624, 2.6985730146085505, 2.6962894918501195, 
2.694909587802492, 2.6939198329458662, 2.6920355909168054, 2.6894797176987, 2.6847783992534953, 2.6876683766362706, 2.688193780216682, 2.6858834745328304, 2.6844699249666353, 2.6769096308066009, 2.6797439875382247, 2.673058143965358, 2.6750399642730227, 2.6737576049729102, 2.6753382124469258, 2.6737873508708652, 2.6721650293855133, 2.6693526407446213, 2.6711041560819941, 2.6677963869037198, 2.6672860460032015, 2.6632485977088582, 2.6644035358635141, 2.6631436079332325, 2.6566343838166611, 2.6601207297786087, 2.6581905318023229, 2.6552241136496249, 2.6555616306032186, 2.6584988784057306, 2.6543459980608777, 2.6533082700120127, 2.6562184235342698, 2.6542637954192925, 2.6572559096193937, 2.6521108798425561, 2.6495210566046978, 2.652091601040206, 2.6478717072987887, 2.6535204692860783, 2.6484477276931919, 2.6463775382012615, 2.6487259651895991, 2.6445075797922408, 2.6451952833711649, 2.6456859566066848, 2.6425004589946308, 2.6417269834150292, 2.6403920766773492, 2.641931022609934, 2.6394455417269764, 2.6394948161310552, 2.635177597321277, 2.6341701694023785, 2.6343344802468636, 2.633896529357159, 2.6293822572029852, 2.6345174371768203, 2.6318242656802711, 2.6290698618253541, 2.6312648199778086, 2.6313202334597539, 2.6310094109841677, 2.6286863155556759, 2.6253288958619159, 2.6289270393813338, 2.6243763804079756, 2.6232706349848942, 2.621821300192225, 2.6230450281460409, 2.6160250581134727, 2.6175467018929153, 2.6213394528875349, 2.6214553652760779, 2.6213017753765837, 2.6180231129473581, 2.6202923731318624, 2.6160445869792901, 2.6149866961059884, 2.615761203723916, 2.6163748388730719, 2.6218868760527982, 2.6162160790770415, 2.6187273080432374, 2.611746250624897, 2.6145067925057339, 2.6133254007677587, 2.610106943369733, 2.6107807661045968, 2.6061441473407734, 2.6076910653790688, 2.6054324568613949, 2.6060570722546119, 2.6065591089867088, 2.6046468492992938, 2.6048573287319403, 2.6053772866713394, 2.6030519570134132, 2.6049593587307931, 2.6038790593316556, 
2.6003451555628168, 2.6021467791348858, 2.5983899443828498, 2.5993880603086974, 2.6021637563885203, 2.6011319024926265, 2.5989272776194934, 2.5932563440650016, 2.5968204720479457, 2.5934980043593359, 2.5931513131977133, 2.5919612142102508, 2.5911954259474386, 2.5918066515199873, 2.5923698084015294, 2.5941896416165426, 2.5916639563008821, 2.5898004464872559, 2.589922301931435, 2.5919901795965927, 2.5882568136716211, 2.5853604248886541, 2.5846250592866644, 2.5822657976232133, 2.5806467964027191, 2.5836269488500836, 2.5886548952242787, 2.5851038237303734, 2.5841579294147889, 2.5850935371886155, 2.5841115555956153, 2.5851836437551539, 2.5819626313589694, 2.5824423174346802, 2.5782770199985006, 2.5811036722631724, 2.5796037219889145, 2.5798519970942211, 2.5767618214294234, 2.5789373301009637, 2.5747406167638345, 2.5792743409560823, 2.576747040697358, 2.5776113195519841, 2.574209829919416, 2.5747708112227281, 2.5742246909935345, 2.5710951080383313, 2.5729644361934052, 2.5719885459633782, 2.5733376347739707, 2.5743083095913781, 2.5718457673453128, 2.5721229856256627, 2.568116109002069, 2.5663951523490223, 2.5680249463868012, 2.562615184441392, 2.5632287918776431, 2.564661376928254, 2.5664542266678576, 2.5625726421627473, 2.5639679687987473, 2.5645019343528888, 2.5630600889582889, 2.5595768274188937, 2.5577499105034511, 2.5604062927978211, 2.5560835459518203, 2.5573004408357631, 2.55920318802956, 2.5593419743212196, 2.5592681447677155, 2.558165598393872, 2.5589704574204881, 2.5551933203606425, 2.5543929280887117, 2.5536716555763523, 2.5540659861496691, 2.553909507448946, 2.5504221837864924, 2.5483296161255131, 2.5483656716265242, 2.5469882173440195, 2.5486124491456081, 2.5465970834627285, 2.5483524676376943, 2.5445821380246687, 2.5417514209496415, 2.5404979096947895, 2.5388747145000305, 2.5396636886940835, 2.5388957255427522, 2.5403161929143523, 2.5404497566927766, 2.5393327770304475, 2.5368184304321226, 2.5366729361209202, 2.5375696346383263, 2.5378043098788763, 
2.5368839646408161, 2.5346830439347219, 2.5316015493665618, 2.5315622501818775, 2.5349907009824353, 2.5328213784189426, 2.5291148147614781, 2.5303271832997787, 2.5329121015277587, 2.5298818563249958, 2.5306162405527082, 2.528534008147949, 2.5307865731126946, 2.530602843803305, 2.5319366760976054, 2.5338630376209639, 2.5326934683491253, 2.5306325210511549, 2.5317225485264334, 2.5263184086523123, 2.5262510313284965, 2.5264629090442701, 2.5238843934178781, 2.5237533315013962, 2.5236819906911303, 2.5250513838553967, 2.5248798947634228, 2.5257865523350076, 2.5270958218177064, 2.5248001023254449, 2.5230538760229009, 2.5199145022087222, 2.5211728535801536, 2.5205171702099989, 2.5165420437914818, 2.5167343762142007, 2.5144070490112078, 2.516257932430638, 2.5168061120129464, 2.5143041502266628, 2.5129286780712836, 2.5112573173520354, 2.5110043962215749, 2.5102762368938532, 2.5083685702124114, 2.5099884368430527, 2.5091212146302291, 2.5115196818373327, 2.5114366707347142, 2.510391327866333, 2.5070728838314471, 2.5064036285821216, 2.5096419011502706, 2.5061425432723596, 2.5071796863923637, 2.5059793742513028, 2.504708930899151, 2.5057530148071976, 2.501061520429579, 2.5038890369533355, 2.503046227017014, 2.5026337896815658, 2.5024368750879593, 2.5013637172210768, 2.5049310161099001, 2.5018006203042473, 2.5046718056750232, 2.5042601547876782, 2.5052724358937795, 2.5039828866701792, 2.4995150541295139, 2.5005612433307802, 2.4991316790435909, 2.4965953488279009, 2.4989534703730039, 2.4991746090331812, 2.4979916122339261, 2.4982039232876625, 2.4978864322734444, 2.4962628254821104, 2.497144330890892, 2.498514191276644, 2.4981536548063463, 2.4963938357459994, 2.4949951161538335, 2.4951106908979881, 2.4958613555329636, 2.4981289291986619, 2.4974860677594277, 2.4957707674113423, 2.4944560480312497, 2.4954138931089909, 2.493115400623692, 2.4940313333392692, 2.4933273583654811, 2.4927003579334372, 2.4900455953111251, 2.4933163671568432, 2.4917768503235878, 2.4912174811147243, 
2.4908288262996487, 2.4907817992399583, 2.4932391584542031, 2.4907302035988477, 2.4921300365802357, 2.4882521442517218, 2.4887462843809907, 2.4893718284424065, 2.4882348355180608, 2.4867851340429303, 2.4871214947095996, 2.4835867776300988, 2.4867250222901198, 2.4874951220593058, 2.4858794530869295, 2.4852071258559358, 2.4820395448650636, 2.4823092787810337, 2.4779599771964294, 2.4791845723506016, 2.477108928551043, 2.4778670236685487, 2.47654199686938, 2.4785038950583931, 2.475040799203641, 2.4728145115314373, 2.4718257125647929, 2.4769530303289491, 2.4777158180737455, 2.4730855837094587, 2.4752885729661993, 2.4712323566640069, 2.4718970383297667, 2.4734310124577044, 2.4711750589989117, 2.4719086437242574, 2.471575999272726, 2.4697846874637728, 2.4705153658648564, 2.4688532195439024, 2.4702164254884296, 2.4719420675924586, 2.4699568420671407, 2.469504644932583, 2.4707574546602888, 2.4662071401260635, 2.4662311621692576, 2.4619603116509508, 2.4642646166171227, 2.4667129611943777, 2.4657735091526596, 2.461118169231026, 2.4629121177958502, 2.4590348040748702, 2.4582676307167004, 2.4598776876562143, 2.4607219308810935, 2.4581935192603162, 2.4558945694632155, 2.4564966038513671, 2.4589417459491303, 2.4576981428766782, 2.4576263925016604, 2.4551388322009498, 2.4566508398742628, 2.4542368659105107, 2.4524367052047071, 2.45178142628852, 2.4552949780968811, 2.4558424987136682, 2.4554169934897128, 2.4553103548587902, 2.4569274157219811, 2.4560511775906262, 2.4553892678085556, 2.4565021131182809, 2.4545272386897783, 2.4559748769454841, 2.4529930725071272, 2.4540310528654175, 2.4555575477941511, 2.4542092955622841, 2.4522553768748625, 2.4520992654349496, 2.4503163551405218, 2.4507708258806979, 2.45288269264426, 2.4483271088993726, 2.4456838649243968, 2.4441796565727985, 2.4458767802140633, 2.4447817316756808, 2.4450548036466886, 2.4443901938067141, 2.4445581487201729, 2.4440218169455723, 2.4434951025212257, 2.4416724830530221, 2.4402548866490315, 2.4446625472524963, 
2.4408827883048358, 2.4424834167749618, 2.4428078494609444, 2.4409503591034265, 2.4410381296420072, 2.4426618631120371, 2.4400938248597721, 2.4429047584511214, 2.4385188594548102, 2.4406540272669419, 2.4401767580037168, 2.4388083549058193, 2.4386768589510828, 2.4377838552629183, 2.4367258794557261, 2.4338775293628534, 2.434110883046174, 2.4353256169833313, 2.4309685465107012, 2.4323231389452413, 2.4318592441553051, 2.4289555222997783, 2.428341411200893, 2.4256793931108067, 2.4264862697593204, 2.425459118355652, 2.4267497767420751, 2.4266449044385046, 2.4244994338410861, 2.427791628026327, 2.4240341240754528, 2.4252277299379017, 2.4267749231882703, 2.4258137214514606, 2.4254079264070838, 2.4256009575517825, 2.4252375201422653, 2.4240078365492539, 2.4247485664184381, 2.4191715970641141, 2.4177084402800286, 2.4167797849480759, 2.4142933175736312, 2.4173592905286707, 2.4165319906631977, 2.4187166935761364, 2.4147768130919758, 2.4163729451515534, 2.4177879133944944, 2.4171241540741262, 2.4196054028251273, 2.4173288270268531, 2.4168541658884992, 2.417982456482219, 2.4135986285586992, 2.4147615692918274, 2.4118447289442937, 2.4121782971050458, 2.409355290969291, 2.4083458941010538, 2.4074728522012325, 2.4056330418325005, 2.4046256444888039, 2.4064354529961056, 2.4048393660259908, 2.4072886833300764, 2.4056939016798515, 2.40596813049192, 2.4075949048941445, 2.4049916008389829, 2.4061710628567266, 2.4046424348696487, 2.4028381424176839, 2.4026902099476177, 2.4024878419756828, 2.4010249084337096, 2.398883577252898, 2.399079375015551, 2.3976259428932942, 2.3994370351325287, 2.3989649725586601, 2.3986647353776656, 2.4002633200802861, 2.4014805548818958, 2.3966766948511768, 2.3954537236533424, 2.3987389799479226, 2.394355438664352, 2.3941168935565811, 2.39686384784099, 2.3976205499634249, 2.395745933841642, 2.3945812425525372, 2.392884594591544, 2.3954295822522944, 2.3940255194807265, 2.3946973532599087, 2.3926577877205149, 2.3950174520234069, 2.392480227419453, 
2.389857539007842, 2.3912854590429884, 2.3883992688670719, 2.3908834958292791, 2.3915550903958778, 2.3877242496974391, 2.3866911249839928, 2.3842309856996704, 2.3857586747225352, 2.383856214904684, 2.3827729205814894, 2.3820464237752157, 2.3818849524093553, 2.3831426098990423, 2.3810319334690622, 2.3813689835726715, 2.3823108518417211, 2.3791879347825704, 2.37907449773668, 2.3780177248592937, 2.3780015730382731, 2.3773655182692064, 2.3768346413297339, 2.3783605671318164, 2.3749216050166431, 2.3742762542911717, 2.371988778361402, 2.3739412171837677, 2.374285629602487, 2.3736258808844548, 2.3732296799928498, 2.3730243145968455, 2.3712471381869014, 2.3720779046771501, 2.3676644505425237, 2.371033017383577, 2.3684924802827467, 2.3701580729394651, 2.368183587929146, 2.3671170643465111, 2.3656463848315661, 2.3652835384039901, 2.3649967344579728, 2.3628216648684646, 2.3637369029909818, 2.3627748162265125, 2.3629257747886916, 2.3635167164122843, 2.3617756366624967, 2.3632902833557146, 2.3603235380521124, 2.3589162111974957, 2.3564076775404725, 2.3572814558069446, 2.3575787723942421, 2.3598336482543454, 2.3561518389513041, 2.3582582267144114, 2.3569050557776734, 2.3556367280292143, 2.3598859107429115, 2.3570348525140794, 2.3552765098212061, 2.356346653237599, 2.3575799133869402, 2.3547985165298693, 2.358438308310173, 2.3556338434439903, 2.3549034878401196, 2.3516814089949363, 2.3548224693125559, 2.3523973934729154, 2.3525582169967092, 2.3516620051795516, 2.3525889421081203, 2.3526167853425854, 2.3502297484347259, 2.3509800760110271, 2.3488714804137625, 2.3465479971997145, 2.3517864973221112, 2.3470238043849925, 2.3468127291455101, 2.3494346574627341, 2.3475592521792348, 2.3461834400145287, 2.3475963268300188, 2.3441964818065757, 2.3449812733282607, 2.3435179988744843, 2.3427271354087162, 2.3451719718843069, 2.3423222857306953, 2.3443652568947271, 2.3442507817914882, 2.3453463985426239, 2.3427916188305873, 2.3424037845015557, 2.3408258058628353, 2.340705479319555, 
2.3409886259357298, 2.338497686154339, 2.3396918257902017, 2.3370095488175693, 2.3384533198423307, 2.3363531671618518, 2.3359213321821342, 2.3376687774637714, 2.3367156006870937, 2.3374922787114283, 2.3375235916421868, 2.3382675159825363, 2.3364063328936089, 2.3374733170731092, 2.3341625953310041, 2.335614091433285, 2.3339590855042154, 2.3333835945550274, 2.3354003175119065, 2.333926253576156, 2.3346001717657279, 2.3346439256680185, 2.3313431015807757, 2.3315182392513192, 2.3311568037853023, 2.3284888788945892, 2.3290307071970409, 2.325106459967444, 2.329137407929764, 2.3270823801388842, 2.327336614007006, 2.3286829796555537, 2.3274777576485515, 2.3296765594343447, 2.3310986270128695, 2.3239320426765695, 2.325034897593738, 2.3261905195339443, 2.3225233536197978, 2.324621604494292, 2.3251496479164011, 2.3243036199149953, 2.3222132510600173, 2.3220011975448043, 2.3228468551347818, 2.3248375748167578, 2.3273315003257524, 2.320425958736962, 2.3227825745447599, 2.3209149998623468, 2.323342251157539, 2.3207611555344094, 2.3197780034228344, 2.3181917454164584, 2.3211002276736217, 2.3199750237362382, 2.3197363109079787, 2.3193518064643786, 2.3174773405014162, 2.3195516218158834, 2.3166660158291865, 2.3178289819867142, 2.316712961624761, 2.3172574140684468, 2.3158184283703771, 2.3150919850533476, 2.3147977281071159, 2.3170142771232443, 2.3130506272839182, 2.3154421103661949, 2.3127606232209961, 2.3140721624794285, 2.3123285349341165, 2.3125772213526377, 2.311254045396073, 2.3134533834377735, 2.315004137052997, 2.3128650560568609, 2.3135705392065504, 2.3114503634005725, 2.311801982556243, 2.3144572377748203, 2.3139159539659744, 2.3133424482778286, 2.3140797479536759, 2.3137826578620708, 2.3146474763587763, 2.3143563209524909, 2.3153192405046528, 2.3128254216582604, 2.3122723216477805, 2.3121664320278725, 2.3102799942558776, 2.311886031839133, 2.3092467929853635, 2.3105145584588631, 2.31098416915186, 2.3096793961817905, 2.3112703335106577, 2.3118936882514229, 
2.3112622521471775, 2.3117414690829445, 2.3065169006361432, 2.3078849801400927, 2.308358335074852, 2.3089415794237116, 2.3091123582720523, 2.311284369143809, 2.3082666456801286, 2.3057030767461648, 2.3069749384712477, 2.3057418714287845, 2.3068795863982512, 2.3026504278510904, 2.3060037430694971, 2.3067677842789238, 2.3076478032453727, 2.3059910221946143, 2.3028523084846837, 2.302755679982309, 2.30388550669768, 2.3042878520972923, 2.3026756729130944, 2.303922690596476, 2.3027344743371945, 2.3037110423316234, 2.3035411375623851, 2.3053965418794022, 2.3043063134457658, 2.304535188333162, 2.3016233321686648, 2.3005348112766564, 2.2995774593642642, 2.2991760064220412, 2.2980212064863346, 2.2986373414248988, 2.2992576606572075, 2.3003870576214971, 2.2986848298566334, 2.2999981116088373, 2.3006123223103243, 2.2994209266525809, 2.3002881923648131, 2.3015660218520373, 2.2995223059481904, 2.2995656428041871, 2.3026421867144666, 2.2998614288885153, 2.2999953626809493, 2.3020222850902048, 2.3006742715610691, 2.300193992002062, 2.2968363225963362, 2.2973415092055052, 2.2954936194397475, 2.297249480972452, 2.3001183368183105, 2.2983762180120415, 2.2955051204226131, 2.2994068488318531, 2.2962605222845425, 2.2979940696864447, 2.3020779698423426, 2.3000583203960718, 2.297851868846541, 2.2982425337249177, 2.2990277119829035, 2.298749961522442, 2.2982441186003264, 2.2987077368526427, 2.297191665655987, 2.2962272985001206, 2.2953192428886595, 2.2948044255945219, 2.2954207949310188, 2.2963841522010231, 2.2939609088818274, 2.2962906579083451, 2.2953749176696543, 2.2959042972819086, 2.2966802178272934, 2.2954022689964364, 2.2970280395297031, 2.296973123148244, 2.2968407354011893, 2.2948771031725665, 2.2944170920386453, 2.2954111673442483, 2.2962969883186117, 2.2975692034578192, 2.2941843469526986, 2.2954206866064708, 2.2953134242929356, 2.2939944765324416, 2.2971163915180743, 2.2931141566685262, 2.2928134192505705, 2.2945132699933519, 2.2921416784340152, 2.2926027276561274, 
2.2915791274972332, 2.2929808700057617, 2.2917925586090875, 2.2931807517490066, 2.2918499439325379, 2.294656536794129, 2.2903843996772908, 2.2911382866540335, 2.2904974519806625, 2.2897702312877666, 2.2911246725585839, 2.2907808237832969, 2.2896332257897503, 2.2914290893373641, 2.2888734009114251, 2.2864846172436271, 2.2873104468727234, 2.2910055907382381, 2.2871051173398942, 2.2851346521781171, 2.2908672005815331, 2.2889027584442716, 2.2871518753522064, 2.2867596830848993, 2.2883785660634817, 2.2866042120489825, 2.28817920069877, 2.2889315268461385, 2.2865765569082743, 2.2855485621528784, 2.286919957686147, 2.2864879855246976, 2.28793584438414, 2.2864763690332417, 2.2857059197207232, 2.2860979263484182, 2.2876492958779324, 2.2872504920011165, 2.287417204333305, 2.2870683234986808, 2.286450753230802, 2.2850898622748983, 2.2822627854188391, 2.286911979638155, 2.2869755785453987, 2.2824761267430418, 2.2857956003906659, 2.2820265876282271, 2.2810997792727274, 2.2815641666821307, 2.2835538965411621, 2.2805643086175853, 2.2816775531981648, 2.2825254787572318, 2.2811948782264153, 2.2825146344990497, 2.282754111903103, 2.2830908596699824, 2.2812709387001231, 2.2819972410443086, 2.2818875430567309, 2.2825457949765755, 2.2824629557105252, 2.2838960229814922, 2.2803418045361354, 2.2826600731156175, 2.2813350704033679, 2.2807526587456661, 2.2828164394800421, 2.2851422417793446, 2.2809753873876679, 2.2812822991999315, 2.2823394535889605, 2.2835923142160741, 2.2829732001482363, 2.2796921617443617, 2.2841529004742158, 2.2799784011382194, 2.2820589274177463, 2.2813805541757177, 2.280464881304062, 2.2807335174580259, 2.2818007471503763, 2.2816039765172125, 2.2782487208983784, 2.2781263239720144, 2.2820882923159114, 2.2821401768336869, 2.2789695945552166, 2.2791322615876193, 2.2797851108479716, 2.2819833638823281, 2.2821960942913475, 2.2809085038292123, 2.2799407423350004, 2.2814388033860018, 2.2775334764834683, 2.2816922200030416, 2.2811528164968875, 2.2801223657401284, 
2.2797202598380553]
'''
Пример #29
0
def test_init():
    """Exercise the exception paths and the attribute storage of the
    bnmtf_gibbs_optimised constructor.

    Covers: non-2D R, R/M shape mismatch, wrongly shaped lambdaF/lambdaS/
    lambdaG priors, fully unobserved rows/columns, a successful construction,
    and scalar lambda priors being expanded to full matrices.
    """
    alpha, beta = 3, 1
    I, J, K, L = 5, 3, 1, 2
    M = numpy.ones((2, 3))
    lambdaF = numpy.ones((I, K))
    lambdaS = numpy.ones((K, L))
    lambdaG = numpy.ones((J, L))
    priors = {'alpha': alpha, 'beta': beta, 'lambdaF': lambdaF,
              'lambdaS': lambdaS, 'lambdaG': lambdaG}

    # R that is not two-dimensional is rejected, naming the actual rank.
    for bad_R, rank in [(numpy.ones(3), 1), (numpy.ones((4, 3, 2)), 3)]:
        with pytest.raises(AssertionError) as error:
            bnmtf_gibbs_optimised(bad_R, M, K, L, priors)
        assert str(error.value) == (
            "Input matrix R is not a two-dimensional array, "
            "but instead %s-dimensional." % rank)

    # R and M of different sizes are rejected.
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(numpy.ones((3, 2)), M, K, L, priors)
    assert str(error.value) == (
        "Input matrix R is not of the same size as "
        "the indicator matrix M: (3, 2) and (2, 3) respectively.")

    # Wrongly shaped prior matrices are rejected one at a time.
    I, J, K, L = 2, 3, 1, 2
    R4 = numpy.ones((2, 3))

    priors['lambdaF'] = numpy.ones((2 + 1, 1))
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, M, K, L, priors)
    assert str(error.value) == (
        "Prior matrix lambdaF has the wrong shape: (3, 1) instead of (2, 1).")

    priors['lambdaF'] = numpy.ones((2, 1))
    priors['lambdaS'] = numpy.ones((1 + 1, 2 + 1))
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, M, K, L, priors)
    assert str(error.value) == (
        "Prior matrix lambdaS has the wrong shape: (2, 3) instead of (1, 2).")

    priors['lambdaS'] = numpy.ones((1, 2))
    priors['lambdaG'] = numpy.ones((3, 2 + 1))
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, M, K, L, priors)
    assert str(error.value) == (
        "Prior matrix lambdaG has the wrong shape: (3, 3) instead of (3, 2).")

    # Rows or columns of M without a single observation are rejected.
    priors['lambdaF'] = numpy.ones((I, K))
    priors['lambdaS'] = numpy.ones((K, L))
    priors['lambdaG'] = numpy.ones((J, L))

    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, [[1, 1, 1], [0, 0, 0]], K, L, priors)
    assert str(error.value) == "Fully unobserved row in R, row 1."
    with pytest.raises(AssertionError) as error:
        bnmtf_gibbs_optimised(R4, [[1, 1, 0], [1, 0, 0]], K, L, priors)
    assert str(error.value) == "Fully unobserved column in R, column 2."

    # A valid construction stores every argument as an attribute.
    I, J, K, L = 3, 2, 2, 2
    R5 = 2 * numpy.ones((I, J))
    M = numpy.ones((I, J))
    lambdaF = numpy.ones((I, K))
    lambdaS = numpy.ones((K, L))
    lambdaG = numpy.ones((J, L))
    priors = {'alpha': alpha, 'beta': beta, 'lambdaF': lambdaF,
              'lambdaS': lambdaS, 'lambdaG': lambdaG}
    BNMTF = bnmtf_gibbs_optimised(R5, M, K, L, priors)

    assert numpy.array_equal(BNMTF.R, R5)
    assert numpy.array_equal(BNMTF.M, M)
    assert BNMTF.I == I and BNMTF.J == J
    assert BNMTF.K == K and BNMTF.L == L
    assert BNMTF.size_Omega == I * J
    assert BNMTF.alpha == alpha and BNMTF.beta == beta
    assert numpy.array_equal(BNMTF.lambdaF, lambdaF)
    assert numpy.array_equal(BNMTF.lambdaS, lambdaS)
    assert numpy.array_equal(BNMTF.lambdaG, lambdaG)

    # Scalar lambda priors end up as constant matrices of the right shape.
    priors = {'alpha': alpha, 'beta': beta,
              'lambdaF': 3, 'lambdaS': 4, 'lambdaG': 5}
    BNMTF = bnmtf_gibbs_optimised(R5, M, K, L, priors)

    assert numpy.array_equal(BNMTF.R, R5)
    assert numpy.array_equal(BNMTF.M, M)
    assert BNMTF.I == I and BNMTF.J == J
    assert BNMTF.K == K and BNMTF.L == L
    assert BNMTF.size_Omega == I * J
    assert BNMTF.alpha == alpha and BNMTF.beta == beta
    assert numpy.array_equal(BNMTF.lambdaF, 3 * numpy.ones((I, K)))
    assert numpy.array_equal(BNMTF.lambdaS, 4 * numpy.ones((K, L)))
    assert numpy.array_equal(BNMTF.lambdaG, 5 * numpy.ones((J, L)))
Пример #30
0
def test_alpha_s():
    """Check alpha_s() against the hand-computed value alpha + 6 (tau fixed)."""
    sampler = bnmtf_gibbs_optimised(R, M, K, L, priors)
    sampler.initialise(init_S, init_FG)
    # Pin tau so the computation is deterministic.
    sampler.tau = 3.
    expected_alpha_s = alpha + 6.
    assert sampler.alpha_s() == expected_alpha_s
# Example #31 (scraped-source separator; vote count below)
# 0
# --- Demo script: run the optimised BNMTF Gibbs sampler on the Sanger dataset ---
# NOTE(review): `standardised`, `iterations`, `burnin`, `thinning`, `plt` and
# `load_Sanger` are assumed to be defined/imported earlier in the file -- confirm.

# Data matrix dimensions (I rows x J columns) and factorisation ranks K, L.
I, J, K, L = 622,139,10,5
# Initialisation strategies: S from random draws, F and G via k-means
# (the commented-out 'exp' alternatives initialise to the prior expectation).
init_S = 'random' #'exp' #
init_FG = 'kmeans' #'exp' #

# Prior hyperparameters: alpha, beta for the tau prior; lambdaF/S/G are the
# rate parameters of the priors on F, S, G (expectation 1/lambda, per the
# initialisation tests elsewhere in this file).
alpha, beta = 1., 1.
lambdaF = numpy.ones((I,K))/10.
lambdaS = numpy.ones((K,L))/10.
lambdaG = numpy.ones((J,L))/10.
priors = { 'alpha':alpha, 'beta':beta, 'lambdaF':lambdaF, 'lambdaS':lambdaS, 'lambdaG':lambdaG }

# Load in data
# X_min is the observed data matrix, M its binary observation mask.
(_,X_min,M,_,_,_,_) = load_Sanger(standardised=standardised)

# Run the Gibbs sampler
BNMTF = bnmtf_gibbs_optimised(X_min,M,K,L,priors)
BNMTF.initialise(init_S=init_S,init_FG=init_FG)
BNMTF.run(iterations)

# Also measure the performances on the training data
# (predicting the observed entries themselves, after burn-in and thinning).
performances = BNMTF.predict(M,burnin,thinning)
print performances

# Plot the tau expectation values to check convergence
plt.plot(BNMTF.all_tau)

# Print the performances across iterations (MSE)
print "all_performances = %s" % BNMTF.all_performances['MSE']
'''
all_performances = [21.421197768547213, 6.3583912411769443, 4.1905150872867791, 3.5061419879402549, 3.2401904945288589, 3.1061034167609085, 3.0317883467556381, 2.9898304967466904, 2.9597551532537403, 2.9383503871636449, 2.9254205069284884, 2.9174938663203092, 2.9099499161449942, 2.9062570985645522, 2.9031497813762694, 2.8996412733085539, 2.8933585834560027, 2.8895195553715656, 2.8857902380755402, 2.8826893976014194, 2.8818869316601061, 2.879167497739882, 2.8810542681772024, 2.8774833883780992, 2.8764204742423538, 2.8690108092204154, 2.8725197104371101, 2.8674899173730819, 2.8647480905877538, 2.8626898729711101, 2.8593806833372533, 2.8565979156668497, 2.8534557000365472, 2.8486892736477119, 2.8442410993582907, 2.8424692916558705, 2.8381512986523192, 2.8355096495142504, 2.8377105511241458, 2.8331943900354211, 2.8304945163441682, 2.8222371793873791, 2.8187736926904341, 2.8142125953486694, 2.8123009163914499, 2.8141234880524757, 2.8117717085995619, 2.8059703725159513, 2.8066421988610086, 2.8055944223940736, 2.7997620244475629, 2.7982321373971564, 2.7983070728647377, 2.7970005782459055, 2.7979642884735307, 2.7950424452772769, 2.7929941875296747, 2.7881670598714097, 2.7871400065172396, 2.7816411365618094, 2.7818833625205426, 2.7744479857150077, 2.7730558648870067, 2.7711438261741677, 2.7718152242946443, 2.766961470633531, 2.7654244826300469, 2.7648775450099472, 2.759830401256282, 2.7604965970646158, 2.7565190544527352, 2.751721901129716, 2.7499851477712034, 2.7486851571379924, 2.7485557659818927, 2.7456710521803371, 2.7432921691336118, 2.7401237642034371, 2.7331482067432695, 2.7312389506051078, 2.7296576766624039, 2.7291115245357629, 2.7293278854136069, 2.7256303034532081, 2.7189792815582186, 2.7242246054550825, 2.7207063083185834, 2.7200853566998937, 2.7126052328111712, 2.7104206931618524, 2.7081642594411091, 2.7073972104837329, 2.7086208242493455, 2.7063489671878553, 2.7045511903566073, 2.700216589322828, 2.6985049679557624, 2.6985730146085505, 2.6962894918501195, 
2.694909587802492, 2.6939198329458662, 2.6920355909168054, 2.6894797176987, 2.6847783992534953, 2.6876683766362706, 2.688193780216682, 2.6858834745328304, 2.6844699249666353, 2.6769096308066009, 2.6797439875382247, 2.673058143965358, 2.6750399642730227, 2.6737576049729102, 2.6753382124469258, 2.6737873508708652, 2.6721650293855133, 2.6693526407446213, 2.6711041560819941, 2.6677963869037198, 2.6672860460032015, 2.6632485977088582, 2.6644035358635141, 2.6631436079332325, 2.6566343838166611, 2.6601207297786087, 2.6581905318023229, 2.6552241136496249, 2.6555616306032186, 2.6584988784057306, 2.6543459980608777, 2.6533082700120127, 2.6562184235342698, 2.6542637954192925, 2.6572559096193937, 2.6521108798425561, 2.6495210566046978, 2.652091601040206, 2.6478717072987887, 2.6535204692860783, 2.6484477276931919, 2.6463775382012615, 2.6487259651895991, 2.6445075797922408, 2.6451952833711649, 2.6456859566066848, 2.6425004589946308, 2.6417269834150292, 2.6403920766773492, 2.641931022609934, 2.6394455417269764, 2.6394948161310552, 2.635177597321277, 2.6341701694023785, 2.6343344802468636, 2.633896529357159, 2.6293822572029852, 2.6345174371768203, 2.6318242656802711, 2.6290698618253541, 2.6312648199778086, 2.6313202334597539, 2.6310094109841677, 2.6286863155556759, 2.6253288958619159, 2.6289270393813338, 2.6243763804079756, 2.6232706349848942, 2.621821300192225, 2.6230450281460409, 2.6160250581134727, 2.6175467018929153, 2.6213394528875349, 2.6214553652760779, 2.6213017753765837, 2.6180231129473581, 2.6202923731318624, 2.6160445869792901, 2.6149866961059884, 2.615761203723916, 2.6163748388730719, 2.6218868760527982, 2.6162160790770415, 2.6187273080432374, 2.611746250624897, 2.6145067925057339, 2.6133254007677587, 2.610106943369733, 2.6107807661045968, 2.6061441473407734, 2.6076910653790688, 2.6054324568613949, 2.6060570722546119, 2.6065591089867088, 2.6046468492992938, 2.6048573287319403, 2.6053772866713394, 2.6030519570134132, 2.6049593587307931, 2.6038790593316556, 
2.6003451555628168, 2.6021467791348858, 2.5983899443828498, 2.5993880603086974, 2.6021637563885203, 2.6011319024926265, 2.5989272776194934, 2.5932563440650016, 2.5968204720479457, 2.5934980043593359, 2.5931513131977133, 2.5919612142102508, 2.5911954259474386, 2.5918066515199873, 2.5923698084015294, 2.5941896416165426, 2.5916639563008821, 2.5898004464872559, 2.589922301931435, 2.5919901795965927, 2.5882568136716211, 2.5853604248886541, 2.5846250592866644, 2.5822657976232133, 2.5806467964027191, 2.5836269488500836, 2.5886548952242787, 2.5851038237303734, 2.5841579294147889, 2.5850935371886155, 2.5841115555956153, 2.5851836437551539, 2.5819626313589694, 2.5824423174346802, 2.5782770199985006, 2.5811036722631724, 2.5796037219889145, 2.5798519970942211, 2.5767618214294234, 2.5789373301009637, 2.5747406167638345, 2.5792743409560823, 2.576747040697358, 2.5776113195519841, 2.574209829919416, 2.5747708112227281, 2.5742246909935345, 2.5710951080383313, 2.5729644361934052, 2.5719885459633782, 2.5733376347739707, 2.5743083095913781, 2.5718457673453128, 2.5721229856256627, 2.568116109002069, 2.5663951523490223, 2.5680249463868012, 2.562615184441392, 2.5632287918776431, 2.564661376928254, 2.5664542266678576, 2.5625726421627473, 2.5639679687987473, 2.5645019343528888, 2.5630600889582889, 2.5595768274188937, 2.5577499105034511, 2.5604062927978211, 2.5560835459518203, 2.5573004408357631, 2.55920318802956, 2.5593419743212196, 2.5592681447677155, 2.558165598393872, 2.5589704574204881, 2.5551933203606425, 2.5543929280887117, 2.5536716555763523, 2.5540659861496691, 2.553909507448946, 2.5504221837864924, 2.5483296161255131, 2.5483656716265242, 2.5469882173440195, 2.5486124491456081, 2.5465970834627285, 2.5483524676376943, 2.5445821380246687, 2.5417514209496415, 2.5404979096947895, 2.5388747145000305, 2.5396636886940835, 2.5388957255427522, 2.5403161929143523, 2.5404497566927766, 2.5393327770304475, 2.5368184304321226, 2.5366729361209202, 2.5375696346383263, 2.5378043098788763, 
2.5368839646408161, 2.5346830439347219, 2.5316015493665618, 2.5315622501818775, 2.5349907009824353, 2.5328213784189426, 2.5291148147614781, 2.5303271832997787, 2.5329121015277587, 2.5298818563249958, 2.5306162405527082, 2.528534008147949, 2.5307865731126946, 2.530602843803305, 2.5319366760976054, 2.5338630376209639, 2.5326934683491253, 2.5306325210511549, 2.5317225485264334, 2.5263184086523123, 2.5262510313284965, 2.5264629090442701, 2.5238843934178781, 2.5237533315013962, 2.5236819906911303, 2.5250513838553967, 2.5248798947634228, 2.5257865523350076, 2.5270958218177064, 2.5248001023254449, 2.5230538760229009, 2.5199145022087222, 2.5211728535801536, 2.5205171702099989, 2.5165420437914818, 2.5167343762142007, 2.5144070490112078, 2.516257932430638, 2.5168061120129464, 2.5143041502266628, 2.5129286780712836, 2.5112573173520354, 2.5110043962215749, 2.5102762368938532, 2.5083685702124114, 2.5099884368430527, 2.5091212146302291, 2.5115196818373327, 2.5114366707347142, 2.510391327866333, 2.5070728838314471, 2.5064036285821216, 2.5096419011502706, 2.5061425432723596, 2.5071796863923637, 2.5059793742513028, 2.504708930899151, 2.5057530148071976, 2.501061520429579, 2.5038890369533355, 2.503046227017014, 2.5026337896815658, 2.5024368750879593, 2.5013637172210768, 2.5049310161099001, 2.5018006203042473, 2.5046718056750232, 2.5042601547876782, 2.5052724358937795, 2.5039828866701792, 2.4995150541295139, 2.5005612433307802, 2.4991316790435909, 2.4965953488279009, 2.4989534703730039, 2.4991746090331812, 2.4979916122339261, 2.4982039232876625, 2.4978864322734444, 2.4962628254821104, 2.497144330890892, 2.498514191276644, 2.4981536548063463, 2.4963938357459994, 2.4949951161538335, 2.4951106908979881, 2.4958613555329636, 2.4981289291986619, 2.4974860677594277, 2.4957707674113423, 2.4944560480312497, 2.4954138931089909, 2.493115400623692, 2.4940313333392692, 2.4933273583654811, 2.4927003579334372, 2.4900455953111251, 2.4933163671568432, 2.4917768503235878, 2.4912174811147243, 
2.4908288262996487, 2.4907817992399583, 2.4932391584542031, 2.4907302035988477, 2.4921300365802357, 2.4882521442517218, 2.4887462843809907, 2.4893718284424065, 2.4882348355180608, 2.4867851340429303, 2.4871214947095996, 2.4835867776300988, 2.4867250222901198, 2.4874951220593058, 2.4858794530869295, 2.4852071258559358, 2.4820395448650636, 2.4823092787810337, 2.4779599771964294, 2.4791845723506016, 2.477108928551043, 2.4778670236685487, 2.47654199686938, 2.4785038950583931, 2.475040799203641, 2.4728145115314373, 2.4718257125647929, 2.4769530303289491, 2.4777158180737455, 2.4730855837094587, 2.4752885729661993, 2.4712323566640069, 2.4718970383297667, 2.4734310124577044, 2.4711750589989117, 2.4719086437242574, 2.471575999272726, 2.4697846874637728, 2.4705153658648564, 2.4688532195439024, 2.4702164254884296, 2.4719420675924586, 2.4699568420671407, 2.469504644932583, 2.4707574546602888, 2.4662071401260635, 2.4662311621692576, 2.4619603116509508, 2.4642646166171227, 2.4667129611943777, 2.4657735091526596, 2.461118169231026, 2.4629121177958502, 2.4590348040748702, 2.4582676307167004, 2.4598776876562143, 2.4607219308810935, 2.4581935192603162, 2.4558945694632155, 2.4564966038513671, 2.4589417459491303, 2.4576981428766782, 2.4576263925016604, 2.4551388322009498, 2.4566508398742628, 2.4542368659105107, 2.4524367052047071, 2.45178142628852, 2.4552949780968811, 2.4558424987136682, 2.4554169934897128, 2.4553103548587902, 2.4569274157219811, 2.4560511775906262, 2.4553892678085556, 2.4565021131182809, 2.4545272386897783, 2.4559748769454841, 2.4529930725071272, 2.4540310528654175, 2.4555575477941511, 2.4542092955622841, 2.4522553768748625, 2.4520992654349496, 2.4503163551405218, 2.4507708258806979, 2.45288269264426, 2.4483271088993726, 2.4456838649243968, 2.4441796565727985, 2.4458767802140633, 2.4447817316756808, 2.4450548036466886, 2.4443901938067141, 2.4445581487201729, 2.4440218169455723, 2.4434951025212257, 2.4416724830530221, 2.4402548866490315, 2.4446625472524963, 
2.4408827883048358, 2.4424834167749618, 2.4428078494609444, 2.4409503591034265, 2.4410381296420072, 2.4426618631120371, 2.4400938248597721, 2.4429047584511214, 2.4385188594548102, 2.4406540272669419, 2.4401767580037168, 2.4388083549058193, 2.4386768589510828, 2.4377838552629183, 2.4367258794557261, 2.4338775293628534, 2.434110883046174, 2.4353256169833313, 2.4309685465107012, 2.4323231389452413, 2.4318592441553051, 2.4289555222997783, 2.428341411200893, 2.4256793931108067, 2.4264862697593204, 2.425459118355652, 2.4267497767420751, 2.4266449044385046, 2.4244994338410861, 2.427791628026327, 2.4240341240754528, 2.4252277299379017, 2.4267749231882703, 2.4258137214514606, 2.4254079264070838, 2.4256009575517825, 2.4252375201422653, 2.4240078365492539, 2.4247485664184381, 2.4191715970641141, 2.4177084402800286, 2.4167797849480759, 2.4142933175736312, 2.4173592905286707, 2.4165319906631977, 2.4187166935761364, 2.4147768130919758, 2.4163729451515534, 2.4177879133944944, 2.4171241540741262, 2.4196054028251273, 2.4173288270268531, 2.4168541658884992, 2.417982456482219, 2.4135986285586992, 2.4147615692918274, 2.4118447289442937, 2.4121782971050458, 2.409355290969291, 2.4083458941010538, 2.4074728522012325, 2.4056330418325005, 2.4046256444888039, 2.4064354529961056, 2.4048393660259908, 2.4072886833300764, 2.4056939016798515, 2.40596813049192, 2.4075949048941445, 2.4049916008389829, 2.4061710628567266, 2.4046424348696487, 2.4028381424176839, 2.4026902099476177, 2.4024878419756828, 2.4010249084337096, 2.398883577252898, 2.399079375015551, 2.3976259428932942, 2.3994370351325287, 2.3989649725586601, 2.3986647353776656, 2.4002633200802861, 2.4014805548818958, 2.3966766948511768, 2.3954537236533424, 2.3987389799479226, 2.394355438664352, 2.3941168935565811, 2.39686384784099, 2.3976205499634249, 2.395745933841642, 2.3945812425525372, 2.392884594591544, 2.3954295822522944, 2.3940255194807265, 2.3946973532599087, 2.3926577877205149, 2.3950174520234069, 2.392480227419453, 
2.389857539007842, 2.3912854590429884, 2.3883992688670719, 2.3908834958292791, 2.3915550903958778, 2.3877242496974391, 2.3866911249839928, 2.3842309856996704, 2.3857586747225352, 2.383856214904684, 2.3827729205814894, 2.3820464237752157, 2.3818849524093553, 2.3831426098990423, 2.3810319334690622, 2.3813689835726715, 2.3823108518417211, 2.3791879347825704, 2.37907449773668, 2.3780177248592937, 2.3780015730382731, 2.3773655182692064, 2.3768346413297339, 2.3783605671318164, 2.3749216050166431, 2.3742762542911717, 2.371988778361402, 2.3739412171837677, 2.374285629602487, 2.3736258808844548, 2.3732296799928498, 2.3730243145968455, 2.3712471381869014, 2.3720779046771501, 2.3676644505425237, 2.371033017383577, 2.3684924802827467, 2.3701580729394651, 2.368183587929146, 2.3671170643465111, 2.3656463848315661, 2.3652835384039901, 2.3649967344579728, 2.3628216648684646, 2.3637369029909818, 2.3627748162265125, 2.3629257747886916, 2.3635167164122843, 2.3617756366624967, 2.3632902833557146, 2.3603235380521124, 2.3589162111974957, 2.3564076775404725, 2.3572814558069446, 2.3575787723942421, 2.3598336482543454, 2.3561518389513041, 2.3582582267144114, 2.3569050557776734, 2.3556367280292143, 2.3598859107429115, 2.3570348525140794, 2.3552765098212061, 2.356346653237599, 2.3575799133869402, 2.3547985165298693, 2.358438308310173, 2.3556338434439903, 2.3549034878401196, 2.3516814089949363, 2.3548224693125559, 2.3523973934729154, 2.3525582169967092, 2.3516620051795516, 2.3525889421081203, 2.3526167853425854, 2.3502297484347259, 2.3509800760110271, 2.3488714804137625, 2.3465479971997145, 2.3517864973221112, 2.3470238043849925, 2.3468127291455101, 2.3494346574627341, 2.3475592521792348, 2.3461834400145287, 2.3475963268300188, 2.3441964818065757, 2.3449812733282607, 2.3435179988744843, 2.3427271354087162, 2.3451719718843069, 2.3423222857306953, 2.3443652568947271, 2.3442507817914882, 2.3453463985426239, 2.3427916188305873, 2.3424037845015557, 2.3408258058628353, 2.340705479319555, 
2.3409886259357298, 2.338497686154339, 2.3396918257902017, 2.3370095488175693, 2.3384533198423307, 2.3363531671618518, 2.3359213321821342, 2.3376687774637714, 2.3367156006870937, 2.3374922787114283, 2.3375235916421868, 2.3382675159825363, 2.3364063328936089, 2.3374733170731092, 2.3341625953310041, 2.335614091433285, 2.3339590855042154, 2.3333835945550274, 2.3354003175119065, 2.333926253576156, 2.3346001717657279, 2.3346439256680185, 2.3313431015807757, 2.3315182392513192, 2.3311568037853023, 2.3284888788945892, 2.3290307071970409, 2.325106459967444, 2.329137407929764, 2.3270823801388842, 2.327336614007006, 2.3286829796555537, 2.3274777576485515, 2.3296765594343447, 2.3310986270128695, 2.3239320426765695, 2.325034897593738, 2.3261905195339443, 2.3225233536197978, 2.324621604494292, 2.3251496479164011, 2.3243036199149953, 2.3222132510600173, 2.3220011975448043, 2.3228468551347818, 2.3248375748167578, 2.3273315003257524, 2.320425958736962, 2.3227825745447599, 2.3209149998623468, 2.323342251157539, 2.3207611555344094, 2.3197780034228344, 2.3181917454164584, 2.3211002276736217, 2.3199750237362382, 2.3197363109079787, 2.3193518064643786, 2.3174773405014162, 2.3195516218158834, 2.3166660158291865, 2.3178289819867142, 2.316712961624761, 2.3172574140684468, 2.3158184283703771, 2.3150919850533476, 2.3147977281071159, 2.3170142771232443, 2.3130506272839182, 2.3154421103661949, 2.3127606232209961, 2.3140721624794285, 2.3123285349341165, 2.3125772213526377, 2.311254045396073, 2.3134533834377735, 2.315004137052997, 2.3128650560568609, 2.3135705392065504, 2.3114503634005725, 2.311801982556243, 2.3144572377748203, 2.3139159539659744, 2.3133424482778286, 2.3140797479536759, 2.3137826578620708, 2.3146474763587763, 2.3143563209524909, 2.3153192405046528, 2.3128254216582604, 2.3122723216477805, 2.3121664320278725, 2.3102799942558776, 2.311886031839133, 2.3092467929853635, 2.3105145584588631, 2.31098416915186, 2.3096793961817905, 2.3112703335106577, 2.3118936882514229, 
2.3112622521471775, 2.3117414690829445, 2.3065169006361432, 2.3078849801400927, 2.308358335074852, 2.3089415794237116, 2.3091123582720523, 2.311284369143809, 2.3082666456801286, 2.3057030767461648, 2.3069749384712477, 2.3057418714287845, 2.3068795863982512, 2.3026504278510904, 2.3060037430694971, 2.3067677842789238, 2.3076478032453727, 2.3059910221946143, 2.3028523084846837, 2.302755679982309, 2.30388550669768, 2.3042878520972923, 2.3026756729130944, 2.303922690596476, 2.3027344743371945, 2.3037110423316234, 2.3035411375623851, 2.3053965418794022, 2.3043063134457658, 2.304535188333162, 2.3016233321686648, 2.3005348112766564, 2.2995774593642642, 2.2991760064220412, 2.2980212064863346, 2.2986373414248988, 2.2992576606572075, 2.3003870576214971, 2.2986848298566334, 2.2999981116088373, 2.3006123223103243, 2.2994209266525809, 2.3002881923648131, 2.3015660218520373, 2.2995223059481904, 2.2995656428041871, 2.3026421867144666, 2.2998614288885153, 2.2999953626809493, 2.3020222850902048, 2.3006742715610691, 2.300193992002062, 2.2968363225963362, 2.2973415092055052, 2.2954936194397475, 2.297249480972452, 2.3001183368183105, 2.2983762180120415, 2.2955051204226131, 2.2994068488318531, 2.2962605222845425, 2.2979940696864447, 2.3020779698423426, 2.3000583203960718, 2.297851868846541, 2.2982425337249177, 2.2990277119829035, 2.298749961522442, 2.2982441186003264, 2.2987077368526427, 2.297191665655987, 2.2962272985001206, 2.2953192428886595, 2.2948044255945219, 2.2954207949310188, 2.2963841522010231, 2.2939609088818274, 2.2962906579083451, 2.2953749176696543, 2.2959042972819086, 2.2966802178272934, 2.2954022689964364, 2.2970280395297031, 2.296973123148244, 2.2968407354011893, 2.2948771031725665, 2.2944170920386453, 2.2954111673442483, 2.2962969883186117, 2.2975692034578192, 2.2941843469526986, 2.2954206866064708, 2.2953134242929356, 2.2939944765324416, 2.2971163915180743, 2.2931141566685262, 2.2928134192505705, 2.2945132699933519, 2.2921416784340152, 2.2926027276561274, 
2.2915791274972332, 2.2929808700057617, 2.2917925586090875, 2.2931807517490066, 2.2918499439325379, 2.294656536794129, 2.2903843996772908, 2.2911382866540335, 2.2904974519806625, 2.2897702312877666, 2.2911246725585839, 2.2907808237832969, 2.2896332257897503, 2.2914290893373641, 2.2888734009114251, 2.2864846172436271, 2.2873104468727234, 2.2910055907382381, 2.2871051173398942, 2.2851346521781171, 2.2908672005815331, 2.2889027584442716, 2.2871518753522064, 2.2867596830848993, 2.2883785660634817, 2.2866042120489825, 2.28817920069877, 2.2889315268461385, 2.2865765569082743, 2.2855485621528784, 2.286919957686147, 2.2864879855246976, 2.28793584438414, 2.2864763690332417, 2.2857059197207232, 2.2860979263484182, 2.2876492958779324, 2.2872504920011165, 2.287417204333305, 2.2870683234986808, 2.286450753230802, 2.2850898622748983, 2.2822627854188391, 2.286911979638155, 2.2869755785453987, 2.2824761267430418, 2.2857956003906659, 2.2820265876282271, 2.2810997792727274, 2.2815641666821307, 2.2835538965411621, 2.2805643086175853, 2.2816775531981648, 2.2825254787572318, 2.2811948782264153, 2.2825146344990497, 2.282754111903103, 2.2830908596699824, 2.2812709387001231, 2.2819972410443086, 2.2818875430567309, 2.2825457949765755, 2.2824629557105252, 2.2838960229814922, 2.2803418045361354, 2.2826600731156175, 2.2813350704033679, 2.2807526587456661, 2.2828164394800421, 2.2851422417793446, 2.2809753873876679, 2.2812822991999315, 2.2823394535889605, 2.2835923142160741, 2.2829732001482363, 2.2796921617443617, 2.2841529004742158, 2.2799784011382194, 2.2820589274177463, 2.2813805541757177, 2.280464881304062, 2.2807335174580259, 2.2818007471503763, 2.2816039765172125, 2.2782487208983784, 2.2781263239720144, 2.2820882923159114, 2.2821401768336869, 2.2789695945552166, 2.2791322615876193, 2.2797851108479716, 2.2819833638823281, 2.2821960942913475, 2.2809085038292123, 2.2799407423350004, 2.2814388033860018, 2.2775334764834683, 2.2816922200030416, 2.2811528164968875, 2.2801223657401284, 
2.2797202598380553]
def test_alpha_s():
    """Verify the posterior shape parameter: alpha_s() must return alpha + 6."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.  # fixed value so alpha_s() is deterministic
    assert model.alpha_s() == alpha + 6.
# Example #33 (scraped-source separator; vote count below)
# 0
    all_R.append(R)
    
    
# We now run the VB algorithm on each of the M's for each noise ratio    
# NOTE(review): despite the comment above, this code constructs
# bnmtf_gibbs_optimised, i.e. the Gibbs sampler, not VB -- confirm intent.
# NOTE(review): `metrics`, `noise_ratios`, `all_R`, `all_Ms`, `all_Ms_test`,
# `repeats`, `K`, `L`, `priors`, `init_S`, `init_FG`, `iterations`, `burn_in`
# and `thinning` must be defined earlier in the file.
all_performances = {metric:[] for metric in metrics} 
average_performances = {metric:[] for metric in metrics} # averaged over repeats
for (noise,R,Ms,Ms_test) in zip(noise_ratios,all_R,all_Ms,all_Ms_test):
    print "Trying noise ratio %s." % noise
    
    # Run the algorithm <repeats> times and store all the performances
    for metric in metrics:
        all_performances[metric].append([])
    for (repeat,M,M_test) in zip(range(0,repeats),Ms,Ms_test):
        print "Repeat %s of noise ratio %s." % (repeat+1, noise)
    
        # Train on mask M, then evaluate on the held-out mask M_test.
        BNMF = bnmtf_gibbs_optimised(R,M,K,L,priors)
        BNMF.initialise(init_S,init_FG)
        BNMF.run(iterations)
    
        # Measure the performances
        performances = BNMF.predict(M_test,burn_in,thinning)
        for metric in metrics:
            # Add this metric's performance to the list of <repeat> performances for this noise ratio
            all_performances[metric][-1].append(performances[metric])
            
    # Compute the average across attempts
    for metric in metrics:
        average_performances[metric].append(sum(all_performances[metric][-1])/repeats)
    

    
def test_beta_s():
    """Check beta_s() against a hand-computed value, within float tolerance."""
    model = bnmtf_gibbs_optimised(R, M, K, L, priors)
    model.initialise(init_S, init_FG)
    model.tau = 3.  # fixed so the posterior rate is deterministic
    # Hand computation: F*S = [[1/6+1/6=1/3,..]], F*S*G^T = [[1/15*4=4/15,..]]
    expected_beta_s = beta + 0.5 * (12 * (11. / 15.) ** 2)
    assert abs(model.beta_s() - expected_beta_s) < 1e-14