Example #1
File: nmf_gibbs.py  Project: MXDC/BNMTF
burn_in = 180
thinning = 2

init_UV = 'random'
I, J, K = 100, 80, 10

alpha, beta = 1., 1.
lambdaU = numpy.ones((I, K)) / 10
lambdaV = numpy.ones((J, K)) / 10
priors = {'alpha': alpha, 'beta': beta, 'lambdaU': lambdaU, 'lambdaV': lambdaV}

# Load in data
R = numpy.loadtxt(input_folder + "R.txt")
M = numpy.ones((I, J))

M_test = calc_inverse_M(numpy.loadtxt(input_folder + "M.txt"))

# Run the Gibbs sampler
BNMF = bnmf_gibbs_optimised(R, M, K, priors)
BNMF.initialise(init_UV)
BNMF.run(iterations)

taus = BNMF.all_tau
Us = BNMF.all_U
Vs = BNMF.all_V

# Plot tau against iterations to see that it converges
f, axarr = plt.subplots(3, sharex=True)
x = range(1, len(taus) + 1)
axarr[0].set_title('Convergence of values')
axarr[0].plot(x, taus)
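# --- Hedged sketch, not part of the original example ---
# The figure above allocates three subplots but only the first is filled.
# Assuming Us and Vs are lists of sampled U and V matrices (one per stored
# draw, same length as taus), the remaining panels could track their mean
# entries as a rough convergence check:
mean_Us = [numpy.mean(U) for U in Us]
mean_Vs = [numpy.mean(V) for V in Vs]
axarr[1].plot(x, mean_Us)
axarr[1].set_ylabel('mean of U samples')
axarr[2].plot(x, mean_Vs)
axarr[2].set_ylabel('mean of V samples')
plt.show()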
Example #2
init_S = 'random'
init_FG = 'kmeans'

metrics = ['MSE', 'R^2', 'Rp']

# Load in data
R_true = numpy.loadtxt(input_folder + "R_true.txt")

# For each noise ratio, generate mask matrices for each attempt
M_attempts = 100
all_Ms = [[
    try_generate_M(I, J, fraction_unknown, M_attempts)
    for r in range(0, repeats)
] for noise in noise_ratios]
all_Ms_test = [[calc_inverse_M(M) for M in Ms] for Ms in all_Ms]
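# --- Hedged sketch, not from the BNMTF source ---
# calc_inverse_M is assumed to return the complement of a binary mask, so the
# test mask covers exactly the entries held out of training. A hypothetical
# stand-in illustrating that assumption:
def calc_inverse_M_sketch(M):
    return numpy.ones(M.shape) - M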


# Make sure each M has no empty rows or columns
def check_empty_rows_columns(M, fraction):
    sums_columns = M.sum(axis=0)
    sums_rows = M.sum(axis=1)
    for i, c in enumerate(sums_rows):
        assert c != 0, "Fully unobserved row in M, row %s. Fraction %s." % (
            i, fraction)
    for j, c in enumerate(sums_columns):
        assert c != 0, "Fully unobserved column in M, column %s. Fraction %s." % (
            j, fraction)


for Ms in all_Ms:
    for M in Ms:
        check_empty_rows_columns(M, fraction_unknown)
Example #3
metrics = ['MSE', 'R^2', 'Rp']

# Load in data
R = numpy.loadtxt(input_folder + "R.txt")

# Seed all of the methods the same
numpy.random.seed(3)

# Generate matrices M - one list of M's for each fraction
M_attempts = 100
all_Ms = [
    [try_generate_M(I, J, fraction, M_attempts) for r in range(0, repeats)]
    for fraction in fractions_unknown
]
all_Ms_test = [[calc_inverse_M(M) for M in Ms] for Ms in all_Ms]


# Make sure each M has no empty rows or columns
def check_empty_rows_columns(M, fraction):
    sums_columns = M.sum(axis=0)
    sums_rows = M.sum(axis=1)
    for i, c in enumerate(sums_rows):
        assert c != 0, "Fully unobserved row in M, row %s. Fraction %s." % (i, fraction)
    for j, c in enumerate(sums_columns):
        assert c != 0, "Fully unobserved column in M, column %s. Fraction %s." % (j, fraction)


for Ms, fraction in zip(all_Ms, fractions_unknown):
    for M in Ms:
        check_empty_rows_columns(M, fraction)

Example #4
##########

input_folder = project_location + "BNMTF/experiments/generate_toy/bnmf/"

iterations = 200
I, J, K = 100, 80, 10

init_UV = "exponential"
expo_prior = 1 / 10.0

# Load in data
R = numpy.loadtxt(input_folder + "R.txt")
M = numpy.ones((I, J))
# M = numpy.loadtxt(input_folder+"M.txt")
M_test = calc_inverse_M(numpy.loadtxt(input_folder + "M.txt"))

# Run the VB algorithm
nmf = NMF(R, M, K)
nmf.initialise(init_UV, expo_prior)
nmf.run(iterations)

# Also measure the performances
performances = nmf.predict(M_test)
print "Performance on test set: %s." % performances

# Extract the performances across all iterations
print "np_all_performances = %s" % nmf.all_performances

"""
np_all_performances = {'R^2': [0.3032122365329063, 0.6133166405978905, 0.6431376044334463, 0.6593067085963278, 0.6717514727663727, 0.68228506665832, 0.6917125664190896, 0.7005339760645299, 0.7090723817749585, 0.7175252245809353, 0.7259934237903767, 0.7345036350654565, 0.7430317677039082, 0.75152937384811, 0.7599480851429183, 0.7682550228175197, 0.776436085576062, 0.784490285163333, 0.7924211663319218, 0.8002297289193113, 0.8079102319907545, 0.8154484610474781, 0.8228217510909723, 0.8300003544925733, 0.8369499438010377, 0.8436350246711801, 0.8500228818826858, 0.8560874492822713, 0.8618123094412313, 0.8671921091902559, 0.8722321105692831, 0.876946185389484, 0.8813539787340927, 0.8854780322176982, 0.8893414329509705, 0.8929662251700472, 0.896372547257493, 0.8995783054041361, 0.9025991611102144, 0.905448654072977, 0.9081383595637647, 0.9106780524041636, 0.9130758947401676, 0.9153386756498996, 0.9174721151698999, 0.9194812187586701, 0.9213706451927918, 0.923145040566608, 0.9248092957077478, 0.9263687005314527, 0.9278289900920716, 0.9291962964751085, 0.9304770333355508, 0.9316777441364161, 0.9328049422236668, 0.9338649637036158, 0.9348638457429038, 0.9358072355984199, 0.9367003304354385, 0.9375478448843411, 0.9383540018746, 0.9391225419794407, 0.9398567467937639, 0.9405594723975267, 0.941233189525269, 0.9418800275807032, 0.9425018200834387, 0.9431001495282554, 0.9436763900026273, 0.9442317462684834, 0.9447672883809836, 0.9452839812873185, 0.9457827092068222, 0.9462642949174859, 0.9467295143403786, 0.9471791070055592, 0.9476137830938077, 0.9480342277817988, 0.9484411035871358, 0.9488350513324426, 0.9492166902442043, 0.9495866175899556, 0.9499454081501066, 0.9502936137267446, 0.9506317628149454, 0.9509603605031547, 0.9512798886262251, 0.9515908061651354, 0.9518935498682876, 0.9521885350579884, 0.9524761565801545, 0.9527567898538667, 0.9530307919790423, 0.953298502864411, 0.953560246343582, 0.9538163312538315, 0.954067052459742, 0.9543126918115871, 0.9545535190357457, 0.9547897925610067, 0.9550217602898616, 0.9552496603275022, 0.9554737216830019, 0.9556941649570642, 0.9559112030289433, 0.9561250417519456, 0.9563358806627454, 0.9565439137050391, 0.9567493299633228, 0.956952314398204, 0.9571530485710693, 0.9573517113433512, 0.9575484795342386, 0.9577435285205056, 0.9579370327630805, 0.9581291662469268, 0.9583201028234776, 0.9585100164480038, 0.9586990813076265, 0.9588874718388863, 0.9590753626367163, 0.9592629282590452, 0.9594503429330747, 0.9596377801704293, 0.9598254122988978, 0.9600134099184604, 0.9602019412888035, 0.9603911716547099, 0.9605812625147117, 0.9607723708373238, 0.960964648228191, 0.961158240050655, 0.961353284501683, 0.9615499116448762, 0.9617482424024153, 0.9619483875083532, 0.9621504464266579, 0.9623545062388184, 0.9625606405076733, 0.9627689081263755, 0.9629793521640283, 0.9631919987224604, 0.9634068558217966, 0.9636239123357836, 0.9638431370011578, 0.964064477528489, 0.9642878598447232, 0.9645131874998351, 0.9647403412713568, 0.9649691790007782, 0.9651995356947132, 0.9654312239210366, 0.9656640345257972, 0.9658977376905069, 0.9661320843414578, 0.9663668079131659, 0.9666016264572074, 0.9668362450760232, 0.9670703586492495, 0.9673036548084315, 0.9675358171052365, 0.9677665283091601, 0.9679954737637889, 0.9682223447264301, 0.9684468416146467, 0.9686686770850869, 0.9688875788749307, 0.969103292344028, 0.9693155826659883, 0.9695242366285453, 0.9697290640168207, 0.9699298985669633, 0.9701265984913543, 0.9703190465894936, 0.9705071499702334, 0.9706908394207924, 0.970870068465593, 
0.9710448121632769, 0.9712150656932355, 0.9713808427837262, 0.9715421740323554, 0.9716991051667034, 0.9718516952884712, 0.9720000151391545, 0.9721441454192552, 0.9722841751867779, 0.9724202003545587, 0.972552322300062, 0.972680646595887, 0.9728052818644556, 0.9729263387563182, 0.9730439290482211, 0.9731581648545349, 0.9732691579438102, 0.9733770191510263, 0.9734818578754834, 0.9735837816541444, 0.9736828958005008, 0.9737793030996176, 0.973873103550837], 'MSE': [27.399257558274289, 15.20525691358168, 14.032629735542894, 13.396824297110738, 12.907468259842803, 12.493263724126439, 12.122553290305664, 11.775675679109096, 11.439926417335327, 11.107541680858105, 10.774553096156563, 10.439912502812888, 10.104567201268855, 9.7704222699410597, 9.4393796610303511, 9.1127322415429362, 8.7910345060569508, 8.4743253149837283, 8.162465280701575, 7.8554150873583133, 7.5534004813834965, 7.2569804087480039, 6.967046108024455, 6.6847673226874669, 6.411493885106732, 6.148621512538762, 5.8974366408173555, 5.6589642495486725, 5.4338499088490657, 5.2223041173414577, 5.0241199598142039, 4.8387519655414986, 4.6654276457921826, 4.503260613726126, 4.3513430323738893, 4.2088080729315367, 4.0748638490318934, 3.9488062490930589, 3.8300194278013406, 3.7179709738951754, 3.6122057217453727, 3.5123393034449846, 3.4180507651081431, 3.3290731449570705, 3.2451815183265595, 3.1661790593625017, 3.0918825745213767, 3.022109368953747, 2.9566671243484737, 2.8953478293744874, 2.8379259688326601, 2.7841604153684467, 2.7337989696837179, 2.6865843281374646, 2.6422603713588617, 2.6005779494375738, 2.5612996674416277, 2.5242034625649303, 2.4890849706131246, 2.4557588017460175, 2.424058900905905, 2.3938381803645279, 2.3649676004506817, 2.3373348536950753, 2.3108427852744402, 2.2854076622591495, 2.2609573865460915, 2.2374297308914373, 2.2147706630940118, 2.1929328092128246, 2.1718740922791593, 2.1515565684053737, 2.1319454681042029, 2.1130084378997616, 2.094714966835455, 2.0770359749308032, 2.0599435362848069, 2.0434107082139512, 2.0274114390402396, 2.0119205301809138, 1.9969136322618262, 1.9823672593840194, 1.9682588098926799, 1.9545665856922183, 1.9412698051709747, 1.9283486071185023, 1.9157840447077645, 1.9035580697773249, 1.8916535084005006, 1.8800540291727479, 1.8687441058670919, 1.8577089761633205, 1.8469345980918024, 1.8364076056789749, 1.8261152650610468, 1.8160454320636972, 1.8061865119502274, 1.7965274217357818, 1.7870575551742123, 1.7777667502659633, 1.768645258929094, 1.7596837183334506, 1.7508731233286758, 1.7422047994003589, 1.7336703756587923, 1.7252617574903006, 1.7169710986653441, 1.7087907728828211, 1.7007133449162777, 1.6927315416998205, 1.6848382238325299, 1.6770263580817495, 1.6692889915204125, 1.6616192279404429, 1.6540102071468659, 1.6464550876607285, 1.638947033253777, 1.63147920361462, 1.6240447493150174, 1.6166368111189668, 1.6092485235621385, 1.6018730226352913, 1.5945034573341075, 1.5871330047923813, 1.5797548886949737, 1.572362400668047, 1.5649489243633816, 1.5575079619855172, 1.5500331630499766, 1.5425183552027111, 1.5349575769697661, 1.5273451123385378, 1.5196755270943338, 1.5119437068446611, 1.5041448966581834, 1.496274742223592, 1.4883293323945916, 1.4803052429317463, 1.4721995811792468, 1.4640100313262552, 1.4557348997990296, 1.4473731602151247, 1.4389244972052531, 1.4303893482786443, 1.4217689427769382, 1.4130653368377961, 1.4042814431798316, 1.395421054434353, 1.3864888586961726, 1.377490445956735, 1.3684323041260125, 1.3593218034555188, 1.350167168347705, 1.3409774357809285, 1.3317623998919894, 
1.3225325426335526, 1.3132989508499662, 1.3040732205746344, 1.2948673498245507, 1.2856936216279229, 1.2765644794428934, 1.2674923974843511, 1.2584897487481586, 1.2495686736894238, 1.2407409525614406, 1.23201788434924, 1.2234101750375042, 1.2149278376480281, 1.2065801060811374, 1.1983753643213095, 1.1903210920441649, 1.1824238271172158, 1.1746891449476184, 1.1671216541218847, 1.1597250073282843, 1.1525019261685885, 1.1454542381667523, 1.1385829240729108, 1.1318881734442128, 1.1253694464547834, 1.1190255399381208, 1.1128546557832109, 1.1068544699786291, 1.1010222008101029, 1.0953546749529086, 1.0898483904465743, 1.0844995757832876, 1.0793042445737724, 1.0742582454666738, 1.0693573071847469, 1.0645970787001886, 1.0599731647006541, 1.0554811565976501, 1.0511166594010992, 1.0468753148310463, 1.0427528210616563, 1.0387449494983683, 1.034847558978512, 1.0310566077628192, 1.0273681636528726], 'Rp': [0.71512381537293201, 0.78597849089284466, 0.80377040312125636, 0.81334583838563035, 0.82052008049678993, 0.82657421038573053, 0.83202326012432404, 0.8371542244980632, 0.84214316183087712, 0.84709379671983231, 0.85205523586651666, 0.85703463102892763, 0.86201093305974008, 0.8669506254114846, 0.87182221366857859, 0.87660505338774353, 0.88129075447890504, 0.88587928555038986, 0.89037354341129082, 0.89477503847450479, 0.89908144546019331, 0.90328568609303683, 0.90737606322609299, 0.9113371653681035, 0.91515139153118608, 0.91880094027832948, 0.92227002110764256, 0.925546926331551, 0.92862552426576772, 0.93150581360650297, 0.93419344354135314, 0.93669842851977159, 0.93903349119423063, 0.94121246939821013, 0.94324907488007914, 0.94515610269101924, 0.94694504452279127, 0.94862598661559816, 0.95020766399010603, 0.95169757341970473, 0.95310209275261004, 0.95442659434562371, 0.95567556408930465, 0.95685274242061091, 0.95796129476174885, 0.95900400437735478, 0.95998346849743288, 0.96090227339400358, 0.96176312688951182, 0.96256893552625378, 0.96332282478432618, 0.96402811068901195, 0.96468823742485876, 0.96530669740870023, 0.96588694842463019, 0.96643233847388699, 0.96694604455401278, 0.96743102778100409, 0.96789000459976238, 0.96832543229579471, 0.96873950636706729, 0.96913416721989565, 0.96951111384692057, 0.96987182244771131, 0.97021756826160133, 0.97054944915782926, 0.97086840976155553, 0.97117526509802121, 0.97147072292498671, 0.97175540410903805, 0.97202986059052998, 0.97229459067261648, 0.97255005155395347, 0.97279666919060004, 0.9730348457076714, 0.97326496467791712, 0.97348739463749789, 0.97370249122339148, 0.97391059829682458, 0.97411204837454746, 0.97430716263381678, 0.9744962506975956, 0.97467961035005279, 0.97485752728354735, 0.97503027493874439, 0.9751981144692935, 0.97536129484069811, 0.97552005305844114, 0.9756746145107541, 0.97582519340628837, 0.97597199328420092, 0.97611520757383186, 0.97625502018227517, 0.97639160609012488, 0.97652513193894996, 0.97665575659748682, 0.97678363169752269, 0.97690890213449544, 0.97703170653158888, 0.97715217766970741, 0.97727044288808629, 0.97738662446238889, 0.97750083996796688, 0.9776132026357256, 0.97772382170727035, 0.97783280279431639, 0.97794024824493742, 0.97804625751700502, 0.97815092755664912, 0.9782543531772393, 0.97835662743262508, 0.97845784197685504, 0.97855808740231209, 0.97865745354755695, 0.97875602976717369, 0.97885390515667281, 0.97895116872698307, 0.97904790952470111, 0.97914421669594753, 0.97924017949340969, 0.97933588722766496, 0.97943142916489589, 0.97952689437441842, 0.97962237152963094, 0.97971794866657924, 0.97981371290398489, 
0.97990975012864823, 0.98000614464941005, 0.98010297882253772, 0.98020033265071238, 0.98029828335732028, 0.98039690493741183, 0.98049626768616827, 0.9805964377058628, 0.98069747639213822, 0.98079943990093699, 0.98090237859769924, 0.98100633649133684, 0.98111135065633126, 0.98121745064756094, 0.98132465791371504, 0.9814329852166439, 0.98154243606579927, 0.98165300417829171, 0.98176467297713865, 0.98187741514164317, 0.98199119222529807, 0.98210595435787218, 0.98222164004878687, 0.98233817610919094, 0.98245547770950548, 0.98257344858769857, 0.98269198142163638, 0.98281095837511034, 0.98293025182377347, 0.98304972526164114, 0.98316923438379389, 0.98328862833456931, 0.98340775110480849, 0.98352644305513437, 0.98364454253751321, 0.98376188758200633, 0.98387831761255629, 0.98399367515335456, 0.98410780748657034, 0.98422056822347126, 0.98433181875321962, 0.98444142953784319, 0.98454928122694796, 0.98465526557202088, 0.98475928612691854, 0.98486125872826813, 0.98496111175656631, 0.9850587861851805, 0.98515423543072245, 0.98524742502282669, 0.98533833211559185, 0.98542694486548921, 0.98551326170203901, 0.98559729051808975, 0.98567904780546234, 0.98575855776085874, 0.98583585138377949, 0.98591096558626379, 0.98598394233059561, 0.98605482780829279, 0.98612367167012538, 0.9861905263143651, 0.9862554462372024, 0.98631848744719386, 0.98637970694335364, 0.98643916225483497, 0.98649691103895787, 0.98655301073320478, 0.98660751825644044, 0.98666048975406639, 0.98671198038194607, 0.98676204412391155, 0.98681073363811789, 0.98685810012779507]}
"""

Example #5
R = numpy.dot(F, numpy.dot(S, G.T))

iterations = 50
init_S = 'random'
init_FG = 'kmeans'
I, J, K, L = 20, 20, 3, 2
fraction_unknown = 0.6

alpha, beta = 1., 1.
lambdaF = numpy.ones((I, K)) * 1
lambdaS = numpy.ones((K, L)) / 50
lambdaG = numpy.ones((J, L)) * 1
priors = {'alpha': alpha, 'beta': beta, 'lambdaF': lambdaF, 'lambdaS': lambdaS, 'lambdaG': lambdaG}

M = generate_M(I, J, fraction_unknown)
M_test = calc_inverse_M(M)

# Run the variational Bayes algorithm
BNMTF = bnmtf_vb(R, M, K, L, priors)
BNMTF.initialise(init_S=init_S, init_FG=init_FG)
BNMTF.run(iterations)

# Also measure the performances
performances = BNMTF.predict(M_test)
print performances

# Plot the tau expectation values to check convergence
plt.plot(BNMTF.all_exp_tau)

print "F: %s." % BNMTF.expF
print "S: %s." % BNMTF.expS
Example #6
fraction_unknown = 0.6

alpha, beta = 1., 1.
lambdaF = numpy.ones((I, K)) * 1
lambdaS = numpy.ones((K, L)) / 50
lambdaG = numpy.ones((J, L)) * 1
priors = {
    'alpha': alpha,
    'beta': beta,
    'lambdaF': lambdaF,
    'lambdaS': lambdaS,
    'lambdaG': lambdaG
}

M = generate_M(I, J, fraction_unknown)
M_test = calc_inverse_M(M)

# Run the variational Bayes algorithm
BNMTF = bnmtf_vb(R, M, K, L, priors)
BNMTF.initialise(init_S=init_S, init_FG=init_FG)
BNMTF.run(iterations)

# Also measure the performances
performances = BNMTF.predict(M_test)
print performances

# Plot the tau expectation values to check convergence
plt.plot(BNMTF.all_exp_tau)

print "F: %s." % BNMTF.expF
print "S: %s." % BNMTF.expS