import itertools
import math
import random

import numpy
import scipy
import pytest

# Assumed import path - adjust to wherever the NMF class is defined in this project.
from BNMTF.code.nmf_np import NMF


def test_predict():
    (I, J, K) = (5, 3, 2)
    R = numpy.array([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]], dtype=float)
    M = numpy.ones((I, J))
    U = numpy.array([[125., 126.], [126., 126.], [126., 126.], [126., 126.], [126., 126.]])
    V = numpy.array([[84., 84.], [84., 84.], [84., 84.]])
    M_test = numpy.array([[0, 0, 1], [0, 1, 0], [0, 0, 0], [1, 1, 0], [0, 0, 0]])
    # Observed test entries: R -> 3, 5, 10, 11; predicted: R_pred -> 21084, 21168, 21168, 21168
    MSE = (444408561. + 447872569. + 447660964. + 447618649.) / 4.
    R2 = 1. - (444408561. + 447872569. + 447660964. + 447618649.) / (4.25**2 + 2.25**2 + 2.75**2 + 3.75**2)  # mean = 7.25
    Rp = 357. / (math.sqrt(44.75) * math.sqrt(5292.))
    # mean = 7.25, sum sq dev = 44.75; mean_pred = 21147, sum sq dev_pred = 5292;
    # sum of cross-deviations = (-4.25*-63 + -2.25*21 + 2.75*21 + 3.75*21) = 357
    nmf = NMF(R, M, K)
    nmf.U = U
    nmf.V = V
    performances = nmf.predict(M_test)
    assert performances['MSE'] == MSE
    assert performances['R^2'] == R2
    assert performances['Rp'] == Rp
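# Editor's sketch (not part of the original suite): the constants above can be
# recomputed directly. R_pred = U V^T is 21084 for row 0 (125*84 + 126*84) and
# 21168 elsewhere (126*84 + 126*84), which gives the four squared errors used
# in MSE and R^2.
def check_predict_constants():
    U = numpy.array([[125., 126.], [126., 126.], [126., 126.], [126., 126.], [126., 126.]])
    V = numpy.array([[84., 84.], [84., 84.], [84., 84.]])
    R_pred = numpy.dot(U, V.T)
    assert R_pred[0, 2] == 21084. and R_pred[3, 0] == 21168.
    # Squared errors of the four test entries against R = 3, 5, 10, 11:
    errors = (numpy.array([21084., 21168., 21168., 21168.]) - numpy.array([3., 5., 10., 11.]))**2
    assert list(errors) == [444408561., 447872569., 447660964., 447618649.]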
def test_run():
    # Data generated from W = [[1,2],[3,4]], H = [[4,3],[2,1]]
    R = [[8, 5], [20, 13]]
    M = [[1, 1], [1, 0]]
    K = 2
    U = numpy.array([[10, 9], [8, 7]], dtype='f')  # 2x2
    V = numpy.array([[6, 4], [5, 3]], dtype='f')  # 2x2
    nmf = NMF(R, M, K)
    # Check we get an exception if U, V are undefined
    with pytest.raises(AssertionError) as error:
        nmf.run(0)
    assert str(error.value) == "U and V have not been initialised - please run NMF.initialise() first."
    # Then check for 1 iteration whether the updates work - check just the first entry of U
    nmf.U = U
    nmf.V = V
    nmf.run(1)
    U_00 = 10 * (6 * 8 / 96.0 + 5 * 5 / 77.0) / (5.0 + 6.0)  # 0.74970484061
    assert abs(U_00 - nmf.U[0][0]) < 0.000001
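# The update exercised above follows the standard multiplicative rule for
# I-divergence NMF under a mask M. A minimal sketch, assuming this is what
# NMF.update_U computes (the class internals are not shown in this file):
#
#   U[i,k] <- U[i,k] * sum_j(M[i,j] * V[j,k] * R[i,j] / (U V^T)[i,j])
#                    / sum_j(M[i,j] * V[j,k])
#
def update_U_sketch(U, V, R, M, k):
    R, M = numpy.asarray(R, dtype=float), numpy.asarray(M, dtype=float)
    R_pred = numpy.dot(U, V.T)  # current reconstruction
    numerator = (M * V[:, k] * R / R_pred).sum(axis=1)
    denominator = (M * V[:, k]).sum(axis=1)
    return U[:, k] * numerator / denominator
# For the data in test_run this reproduces U_00: 10 * (6*8/96 + 5*5/77) / (6+5).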
def test_compute_statistics():
    R = numpy.array([[1, 2], [3, 4]], dtype=float)
    M = numpy.array([[1, 1], [0, 1]])
    (I, J, K) = 2, 2, 3
    nmf = NMF(R, M, K)
    R_pred = numpy.array([[500, 550], [1220, 1342]], dtype=float)
    M_pred = numpy.array([[0, 0], [1, 1]])
    MSE_pred = (1217**2 + 1338**2) / 2.0
    R2_pred = 1. - (1217**2 + 1338**2) / (0.5**2 + 0.5**2)  # mean = 3.5
    Rp_pred = 61. / (math.sqrt(.5) * math.sqrt(7442.))
    # mean = 3.5, sum sq dev = 0.5; mean_pred = 1281, sum sq dev_pred = 7442; cov = 61
    assert MSE_pred == nmf.compute_MSE(M_pred, R, R_pred)
    assert R2_pred == nmf.compute_R2(M_pred, R, R_pred)
    assert Rp_pred == nmf.compute_Rp(M_pred, R, R_pred)
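# A minimal numpy sketch (editor's assumption about the implementation) of the
# three statistics the test pins down, restricted to entries where the mask is 1:
def statistics_sketch(M, R, R_pred):
    M = numpy.asarray(M, dtype=bool)
    x, y = R[M], R_pred[M]  # observed and predicted values
    MSE = ((x - y)**2).mean()
    R2 = 1. - ((x - y)**2).sum() / ((x - x.mean())**2).sum()
    Rp = ((x - x.mean()) * (y - y.mean())).sum() / \
        (numpy.sqrt(((x - x.mean())**2).sum()) * numpy.sqrt(((y - y.mean())**2).sum()))
    return MSE, R2, Rp
# With x = [3, 4] and y = [1220, 1342] this gives exactly the constants above:
# errors 1217 and 1338, cov sum 61, deviation sums 0.5 and 7442.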
def test_compute_I_div():
    R = [[1, 2, 0, 4], [5, 0, 7, 0]]
    M = [[1, 1, 0, 1], [1, 0, 1, 0]]
    K = 2
    U = numpy.array([[1, 2], [3, 4]], dtype='f')  # 2x2
    V = numpy.array([[5, 7, 9, 11], [6, 8, 10, 12]], dtype='f').T  # 4x2
    # R_pred = [[17,23,29,35],[39,53,67,81]]
    expected_I_div = sum([
        1.0 * math.log(1.0 / 17.0) - 1.0 + 17.0,
        2.0 * math.log(2.0 / 23.0) - 2.0 + 23.0,
        4.0 * math.log(4.0 / 35.0) - 4.0 + 35.0,
        5.0 * math.log(5.0 / 39.0) - 5.0 + 39.0,
        7.0 * math.log(7.0 / 67.0) - 7.0 + 67.0,
    ])
    nmf = NMF(R, M, K)
    nmf.U = U
    nmf.V = V
    I_div = nmf.compute_I_div()
    assert I_div == expected_I_div
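# The expected value above is the generalised KL (I-)divergence, summed over
# observed entries only. A minimal sketch, assuming this is what compute_I_div
# does; the mask keeps only observed entries, so the zeros in R never reach the log:
def I_div_sketch(R, M, U, V):
    R, M = numpy.asarray(R, dtype=float), numpy.asarray(M, dtype=bool)
    R_pred = numpy.dot(U, V.T)
    x, y = R[M], R_pred[M]
    return (x * numpy.log(x / y) - x + y).sum()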
def test_init():
    # Test getting an exception when R and M are different sizes, and when R is not a 2D array
    R1 = numpy.ones(3)
    M = numpy.ones((2, 3))
    K = 0
    with pytest.raises(AssertionError) as error:
        NMF(R1, M, K)
    assert str(error.value) == "Input matrix R is not a two-dimensional array, but instead 1-dimensional."
    R2 = numpy.ones((4, 3, 2))
    with pytest.raises(AssertionError) as error:
        NMF(R2, M, K)
    assert str(error.value) == "Input matrix R is not a two-dimensional array, but instead 3-dimensional."
    R3 = numpy.ones((3, 2))
    with pytest.raises(AssertionError) as error:
        NMF(R3, M, K)
    assert str(error.value) == "Input matrix R is not of the same size as the indicator matrix M: (3, 2) and (2, 3) respectively."
    # Test getting an exception if a row or column is entirely unknown
    R = numpy.ones((2, 3))
    M1 = [[1, 1, 1], [0, 0, 0]]
    M2 = [[1, 1, 0], [1, 0, 0]]
    with pytest.raises(AssertionError) as error:
        NMF(R, M1, K)
    assert str(error.value) == "Fully unobserved row in R, row 1."
    with pytest.raises(AssertionError) as error:
        NMF(R, M2, K)
    assert str(error.value) == "Fully unobserved column in R, column 2."
    # Test whether we made a copy of R with 1's at unknown values
    I, J = 2, 4
    R = [[1, 2, 0, 4], [5, 0, 7, 0]]
    M = [[1, 1, 0, 1], [1, 0, 1, 0]]
    K = 2
    R_excl_unknown = [[1, 2, 1, 4], [5, 1, 7, 1]]
    nmf = NMF(R, M, K)
    assert numpy.array_equal(R, nmf.R)
    assert numpy.array_equal(M, nmf.M)
    assert nmf.I == I
    assert nmf.J == J
    assert nmf.K == K
    assert numpy.array_equal(R_excl_unknown, nmf.R_excl_unknown)
def test_initialisation():
    I, J = 2, 3
    R = numpy.ones((I, J))
    M = numpy.ones((I, J))
    K = 4
    # Init ones
    init_UV = 'ones'
    nmf = NMF(R, M, K)
    nmf.initialise(init_UV)
    assert numpy.array_equal(numpy.ones((2, 4)), nmf.U)
    assert numpy.array_equal(numpy.ones((3, 4)), nmf.V)
    # Init random
    init_UV = 'random'
    nmf = NMF(R, M, K)
    nmf.initialise(init_UV)
    for (i, k) in itertools.product(range(0, I), range(0, K)):
        assert nmf.U[i, k] > 0 and nmf.U[i, k] < 1
    for (j, k) in itertools.product(range(0, J), range(0, K)):
        assert nmf.V[j, k] > 0 and nmf.V[j, k] < 1
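# The test above pins down 'ones' (all-ones factors) and 'random' (uniform(0,1)
# entries); the experiment scripts below also pass init_UV='exponential' with an
# expo_prior parameter. A minimal sketch of such an initialiser - the
# 'exponential' branch, and whether expo_prior is a rate or a scale, are
# assumptions based only on how the scripts call it:
def initialise_sketch(I, J, K, init_UV, expo_prior=1.):
    if init_UV == 'ones':
        U, V = numpy.ones((I, K)), numpy.ones((J, K))
    elif init_UV == 'random':
        U, V = numpy.random.rand(I, K), numpy.random.rand(J, K)
    elif init_UV == 'exponential':
        # numpy's exponential takes the scale = 1/rate; here expo_prior is treated as the rate
        U = numpy.random.exponential(scale=1. / expo_prior, size=(I, K))
        V = numpy.random.exponential(scale=1. / expo_prior, size=(J, K))
    return U, V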
input_folder = project_location + "BNMTF/experiments/generate_toy/bnmf/"

iterations = 200
I, J, K = 100, 80, 10

init_UV = "exponential"
expo_prior = 1 / 10.0

# Load in data
R = numpy.loadtxt(input_folder + "R.txt")
M = numpy.ones((I, J))
# M = numpy.loadtxt(input_folder+"M.txt")
M_test = calc_inverse_M(numpy.loadtxt(input_folder + "M.txt"))

# Run the NMF algorithm
nmf = NMF(R, M, K)
nmf.initialise(init_UV, expo_prior)
nmf.run(iterations)

# Also measure the performances
performances = nmf.predict(M_test)
print "Performance on test set: %s." % performances

# Extract the performances across all iterations
print "np_all_performances = %s" % nmf.all_performances
"""
np_all_performances = {'R^2': [0.3032122365329063, 0.6133166405978905, 0.6431376044334463, 0.6593067085963278, 0.6717514727663727, 0.68228506665832, 0.6917125664190896, 0.7005339760645299, 0.7090723817749585, 0.7175252245809353, 0.7259934237903767, 0.7345036350654565, 0.7430317677039082, 0.75152937384811, 0.7599480851429183, 0.7682550228175197, 0.776436085576062, 0.784490285163333, 0.7924211663319218, 0.8002297289193113, 0.8079102319907545, 0.8154484610474781, 0.8228217510909723, 0.8300003544925733, 0.8369499438010377, 0.8436350246711801, 0.8500228818826858, 0.8560874492822713, 0.8618123094412313, 0.8671921091902559, 0.8722321105692831, 0.876946185389484, 0.8813539787340927, 0.8854780322176982, 0.8893414329509705, 0.8929662251700472, 0.896372547257493, 0.8995783054041361, 0.9025991611102144, 0.905448654072977, 0.9081383595637647, 0.9106780524041636, 0.9130758947401676, 0.9153386756498996, 0.9174721151698999, 0.9194812187586701, 0.9213706451927918, 0.923145040566608, 0.9248092957077478, 0.9263687005314527, 0.9278289900920716, 0.9291962964751085, 0.9304770333355508, 0.9316777441364161, 0.9328049422236668, 0.9338649637036158, 0.9348638457429038, 0.9358072355984199, 0.9367003304354385, 0.9375478448843411, 0.9383540018746, 0.9391225419794407, 0.9398567467937639, 0.9405594723975267, 0.941233189525269, 0.9418800275807032, 0.9425018200834387, 0.9431001495282554, 0.9436763900026273, 0.9442317462684834, 0.9447672883809836, 0.9452839812873185, 0.9457827092068222, 0.9462642949174859, 0.9467295143403786, 0.9471791070055592, 0.9476137830938077, 0.9480342277817988, 0.9484411035871358, 0.9488350513324426, 0.9492166902442043, 0.9495866175899556, 0.9499454081501066, 0.9502936137267446, 0.9506317628149454, 0.9509603605031547, 0.9512798886262251, 0.9515908061651354, 0.9518935498682876, 0.9521885350579884, 0.9524761565801545, 0.9527567898538667, 0.9530307919790423, 0.953298502864411, 0.953560246343582, 0.9538163312538315, 0.954067052459742, 0.9543126918115871, 0.9545535190357457, 0.9547897925610067, 0.9550217602898616, 0.9552496603275022, 0.9554737216830019, 0.9556941649570642, 0.9559112030289433, 0.9561250417519456, 0.9563358806627454, 0.9565439137050391, 0.9567493299633228, 0.956952314398204, 0.9571530485710693, 0.9573517113433512, 0.9575484795342386, 0.9577435285205056, 0.9579370327630805, 0.9581291662469268, 0.9583201028234776, 0.9585100164480038, 0.9586990813076265, 0.9588874718388863, 0.9590753626367163, 0.9592629282590452, 0.9594503429330747, 0.9596377801704293, 0.9598254122988978, 0.9600134099184604, 0.9602019412888035, 0.9603911716547099, 0.9605812625147117, 0.9607723708373238, 0.960964648228191, 0.961158240050655, 0.961353284501683, 0.9615499116448762, 0.9617482424024153, 0.9619483875083532, 0.9621504464266579, 0.9623545062388184, 0.9625606405076733, 0.9627689081263755, 0.9629793521640283, 0.9631919987224604, 0.9634068558217966,
0.9636239123357836, 0.9638431370011578, 0.964064477528489, 0.9642878598447232, 0.9645131874998351, 0.9647403412713568, 0.9649691790007782, 0.9651995356947132, 0.9654312239210366, 0.9656640345257972, 0.9658977376905069, 0.9661320843414578, 0.9663668079131659, 0.9666016264572074, 0.9668362450760232, 0.9670703586492495, 0.9673036548084315, 0.9675358171052365, 0.9677665283091601, 0.9679954737637889, 0.9682223447264301, 0.9684468416146467, 0.9686686770850869, 0.9688875788749307, 0.969103292344028, 0.9693155826659883, 0.9695242366285453, 0.9697290640168207, 0.9699298985669633, 0.9701265984913543, 0.9703190465894936, 0.9705071499702334, 0.9706908394207924, 0.970870068465593, 0.9710448121632769, 0.9712150656932355, 0.9713808427837262, 0.9715421740323554, 0.9716991051667034, 0.9718516952884712, 0.9720000151391545, 0.9721441454192552, 0.9722841751867779, 0.9724202003545587, 0.972552322300062, 0.972680646595887, 0.9728052818644556, 0.9729263387563182, 0.9730439290482211, 0.9731581648545349, 0.9732691579438102, 0.9733770191510263, 0.9734818578754834, 0.9735837816541444, 0.9736828958005008, 0.9737793030996176, 0.973873103550837], 'MSE': [27.399257558274289, 15.20525691358168, 14.032629735542894, 13.396824297110738, 12.907468259842803, 12.493263724126439, 12.122553290305664, 11.775675679109096, 11.439926417335327, 11.107541680858105, 10.774553096156563, 10.439912502812888, 10.104567201268855, 9.7704222699410597, 9.4393796610303511, 9.1127322415429362, 8.7910345060569508, 8.4743253149837283, 8.162465280701575, 7.8554150873583133, 7.5534004813834965, 7.2569804087480039, 6.967046108024455, 6.6847673226874669, 6.411493885106732, 6.148621512538762, 5.8974366408173555, 5.6589642495486725, 5.4338499088490657, 5.2223041173414577, 5.0241199598142039, 4.8387519655414986, 4.6654276457921826, 4.503260613726126, 4.3513430323738893, 4.2088080729315367, 4.0748638490318934, 3.9488062490930589, 3.8300194278013406, 3.7179709738951754, 3.6122057217453727, 3.5123393034449846, 3.4180507651081431, 3.3290731449570705, 3.2451815183265595, 3.1661790593625017, 3.0918825745213767, 3.022109368953747, 2.9566671243484737, 2.8953478293744874, 2.8379259688326601, 2.7841604153684467, 2.7337989696837179, 2.6865843281374646, 2.6422603713588617, 2.6005779494375738, 2.5612996674416277, 2.5242034625649303, 2.4890849706131246, 2.4557588017460175, 2.424058900905905, 2.3938381803645279, 2.3649676004506817, 2.3373348536950753, 2.3108427852744402, 2.2854076622591495, 2.2609573865460915, 2.2374297308914373, 2.2147706630940118, 2.1929328092128246, 2.1718740922791593, 2.1515565684053737, 2.1319454681042029, 2.1130084378997616, 2.094714966835455, 2.0770359749308032, 2.0599435362848069, 2.0434107082139512, 2.0274114390402396, 2.0119205301809138, 1.9969136322618262, 1.9823672593840194, 1.9682588098926799, 1.9545665856922183, 1.9412698051709747, 1.9283486071185023, 1.9157840447077645, 1.9035580697773249, 1.8916535084005006, 1.8800540291727479, 1.8687441058670919, 1.8577089761633205, 1.8469345980918024, 1.8364076056789749, 1.8261152650610468, 1.8160454320636972, 1.8061865119502274, 1.7965274217357818, 1.7870575551742123, 1.7777667502659633, 1.768645258929094, 1.7596837183334506, 1.7508731233286758, 1.7422047994003589, 1.7336703756587923, 1.7252617574903006, 1.7169710986653441, 1.7087907728828211, 1.7007133449162777, 1.6927315416998205, 1.6848382238325299, 1.6770263580817495, 1.6692889915204125, 1.6616192279404429, 1.6540102071468659, 1.6464550876607285, 1.638947033253777, 1.63147920361462, 1.6240447493150174, 1.6166368111189668, 1.6092485235621385, 
1.6018730226352913, 1.5945034573341075, 1.5871330047923813, 1.5797548886949737, 1.572362400668047, 1.5649489243633816, 1.5575079619855172, 1.5500331630499766, 1.5425183552027111, 1.5349575769697661, 1.5273451123385378, 1.5196755270943338, 1.5119437068446611, 1.5041448966581834, 1.496274742223592, 1.4883293323945916, 1.4803052429317463, 1.4721995811792468, 1.4640100313262552, 1.4557348997990296, 1.4473731602151247, 1.4389244972052531, 1.4303893482786443, 1.4217689427769382, 1.4130653368377961, 1.4042814431798316, 1.395421054434353, 1.3864888586961726, 1.377490445956735, 1.3684323041260125, 1.3593218034555188, 1.350167168347705, 1.3409774357809285, 1.3317623998919894, 1.3225325426335526, 1.3132989508499662, 1.3040732205746344, 1.2948673498245507, 1.2856936216279229, 1.2765644794428934, 1.2674923974843511, 1.2584897487481586, 1.2495686736894238, 1.2407409525614406, 1.23201788434924, 1.2234101750375042, 1.2149278376480281, 1.2065801060811374, 1.1983753643213095, 1.1903210920441649, 1.1824238271172158, 1.1746891449476184, 1.1671216541218847, 1.1597250073282843, 1.1525019261685885, 1.1454542381667523, 1.1385829240729108, 1.1318881734442128, 1.1253694464547834, 1.1190255399381208, 1.1128546557832109, 1.1068544699786291, 1.1010222008101029, 1.0953546749529086, 1.0898483904465743, 1.0844995757832876, 1.0793042445737724, 1.0742582454666738, 1.0693573071847469, 1.0645970787001886, 1.0599731647006541, 1.0554811565976501, 1.0511166594010992, 1.0468753148310463, 1.0427528210616563, 1.0387449494983683, 1.034847558978512, 1.0310566077628192, 1.0273681636528726], 'Rp': [0.71512381537293201, 0.78597849089284466, 0.80377040312125636, 0.81334583838563035, 0.82052008049678993, 0.82657421038573053, 0.83202326012432404, 0.8371542244980632, 0.84214316183087712, 0.84709379671983231, 0.85205523586651666, 0.85703463102892763, 0.86201093305974008, 0.8669506254114846, 0.87182221366857859, 0.87660505338774353, 0.88129075447890504, 0.88587928555038986, 0.89037354341129082, 0.89477503847450479, 0.89908144546019331, 0.90328568609303683, 0.90737606322609299, 0.9113371653681035, 0.91515139153118608, 0.91880094027832948, 0.92227002110764256, 0.925546926331551, 0.92862552426576772, 0.93150581360650297, 0.93419344354135314, 0.93669842851977159, 0.93903349119423063, 0.94121246939821013, 0.94324907488007914, 0.94515610269101924, 0.94694504452279127, 0.94862598661559816, 0.95020766399010603, 0.95169757341970473, 0.95310209275261004, 0.95442659434562371, 0.95567556408930465, 0.95685274242061091, 0.95796129476174885, 0.95900400437735478, 0.95998346849743288, 0.96090227339400358, 0.96176312688951182, 0.96256893552625378, 0.96332282478432618, 0.96402811068901195, 0.96468823742485876, 0.96530669740870023, 0.96588694842463019, 0.96643233847388699, 0.96694604455401278, 0.96743102778100409, 0.96789000459976238, 0.96832543229579471, 0.96873950636706729, 0.96913416721989565, 0.96951111384692057, 0.96987182244771131, 0.97021756826160133, 0.97054944915782926, 0.97086840976155553, 0.97117526509802121, 0.97147072292498671, 0.97175540410903805, 0.97202986059052998, 0.97229459067261648, 0.97255005155395347, 0.97279666919060004, 0.9730348457076714, 0.97326496467791712, 0.97348739463749789, 0.97370249122339148, 0.97391059829682458, 0.97411204837454746, 0.97430716263381678, 0.9744962506975956, 0.97467961035005279, 0.97485752728354735, 0.97503027493874439, 0.9751981144692935, 0.97536129484069811, 0.97552005305844114, 0.9756746145107541, 0.97582519340628837, 0.97597199328420092, 0.97611520757383186, 0.97625502018227517, 0.97639160609012488, 
0.97652513193894996, 0.97665575659748682, 0.97678363169752269, 0.97690890213449544, 0.97703170653158888, 0.97715217766970741, 0.97727044288808629, 0.97738662446238889, 0.97750083996796688, 0.9776132026357256, 0.97772382170727035, 0.97783280279431639, 0.97794024824493742, 0.97804625751700502, 0.97815092755664912, 0.9782543531772393, 0.97835662743262508, 0.97845784197685504, 0.97855808740231209, 0.97865745354755695, 0.97875602976717369, 0.97885390515667281, 0.97895116872698307, 0.97904790952470111, 0.97914421669594753, 0.97924017949340969, 0.97933588722766496, 0.97943142916489589, 0.97952689437441842, 0.97962237152963094, 0.97971794866657924, 0.97981371290398489, 0.97990975012864823, 0.98000614464941005, 0.98010297882253772, 0.98020033265071238, 0.98029828335732028, 0.98039690493741183, 0.98049626768616827, 0.9805964377058628, 0.98069747639213822, 0.98079943990093699, 0.98090237859769924, 0.98100633649133684, 0.98111135065633126, 0.98121745064756094, 0.98132465791371504, 0.9814329852166439, 0.98154243606579927, 0.98165300417829171, 0.98176467297713865, 0.98187741514164317, 0.98199119222529807, 0.98210595435787218, 0.98222164004878687, 0.98233817610919094, 0.98245547770950548, 0.98257344858769857, 0.98269198142163638, 0.98281095837511034, 0.98293025182377347, 0.98304972526164114, 0.98316923438379389, 0.98328862833456931, 0.98340775110480849, 0.98352644305513437, 0.98364454253751321, 0.98376188758200633, 0.98387831761255629, 0.98399367515335456, 0.98410780748657034, 0.98422056822347126, 0.98433181875321962, 0.98444142953784319, 0.98454928122694796, 0.98465526557202088, 0.98475928612691854, 0.98486125872826813, 0.98496111175656631, 0.9850587861851805, 0.98515423543072245, 0.98524742502282669, 0.98533833211559185, 0.98542694486548921, 0.98551326170203901, 0.98559729051808975, 0.98567904780546234, 0.98575855776085874, 0.98583585138377949, 0.98591096558626379, 0.98598394233059561, 0.98605482780829279, 0.98612367167012538, 0.9861905263143651, 0.9862554462372024, 0.98631848744719386, 0.98637970694335364, 0.98643916225483497, 0.98649691103895787, 0.98655301073320478, 0.98660751825644044, 0.98666048975406639, 0.98671198038194607, 0.98676204412391155, 0.98681073363811789, 0.98685810012779507]} """ # Plot the MSE values
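# A minimal plotting sketch for the comment above (matplotlib is an assumption;
# the original fragment ends at that comment). Plots the recorded MSE per iteration:
import matplotlib.pyplot as plt

all_mse = nmf.all_performances['MSE']
plt.plot(range(1, len(all_mse) + 1), all_mse)
plt.xlabel("Iteration")
plt.ylabel("MSE")
plt.show()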
check_empty_rows_columns(M, fraction)

# We now run the NMF algorithm on each of the M's, for each fraction.
all_performances = {metric: [] for metric in metrics}
average_performances = {metric: [] for metric in metrics}  # averaged over repeats
for (fraction, Ms, Ms_test) in zip(fractions_unknown, all_Ms, all_Ms_test):
    print "Trying fraction %s." % fraction

    # Run the algorithm <repeats> times and store all the performances
    for metric in metrics:
        all_performances[metric].append([])
    for (repeat, M, M_test) in zip(range(0, repeats), Ms, Ms_test):
        print "Repeat %s of fraction %s." % (repeat + 1, fraction)

        nmf = NMF(R, M, K)
        nmf.initialise(init_UV, expo_prior)
        nmf.run(iterations)

        # Measure the performances
        performances = nmf.predict(M_test)
        for metric in metrics:
            # Add this metric's performance to the list of <repeat> performances for this fraction
            all_performances[metric][-1].append(performances[metric])

    # Compute the average across attempts
    for metric in metrics:
        average_performances[metric].append(sum(all_performances[metric][-1]) / repeats)
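# A short follow-up sketch (not in the original fragment): report the averaged
# performances per fraction, ready for pasting into a plotting script.
for metric in metrics:
    print "Average %s per fraction: %s" % (metric, zip(fractions_unknown, average_performances[metric]))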
def test_update():
    I, J = 2, 4
    R = [[1, 2, 0, 4], [5, 0, 7, 0]]
    M = [[1, 1, 0, 1], [1, 0, 1, 0]]
    K = 2
    U = numpy.array([[1, 2], [3, 4]], dtype='f')  # 2x2
    V = numpy.array([[5, 6], [7, 8], [9, 10], [11, 12]], dtype='f')  # 4x2
    new_U = [
        [1 * (1 * 5 / 17.0 + 2 * 7 / 23.0 + 4 * 11 / 35.0) / float(5 + 7 + 11),
         2 * (1 * 6 / 17.0 + 2 * 8 / 23.0 + 4 * 12 / 35.0) / float(6 + 8 + 12)],
        [3 * (5 * 5 / 39.0 + 7 * 9 / 67.0) / float(5 + 9),
         4 * (5 * 6 / 39.0 + 7 * 10 / 67.0) / float(6 + 10)],
    ]
    new_V = [
        [5 * (1 * 1 / 17.0 + 3 * 5 / 39.0) / float(1 + 3),
         6 * (2 * 1 / 17.0 + 4 * 5 / 39.0) / float(2 + 4)],
        [7 * (1 * 2 / 23.0) / float(1), 8 * (2 * 2 / 23.0) / float(2)],
        [9 * (3 * 7 / 67.0) / float(3), 10 * (4 * 7 / 67.0) / float(4)],
        [11 * (1 * 4 / 35.0) / float(1), 12 * (2 * 4 / 35.0) / float(2)],
    ]
    nmf = NMF(R, M, K)

    def reset():
        nmf.U = numpy.copy(U)
        nmf.V = numpy.copy(V)

    for k in range(0, K):
        reset()
        nmf.update_U(k)
        for i in range(0, I):
            assert abs(new_U[i][k] - nmf.U[i, k]) < 0.00001
    for k in range(0, K):
        reset()
        nmf.update_V(k)
        for j in range(0, J):
            assert abs(new_V[j][k] - nmf.V[j, k]) < 0.00001

    # Also if I = J
    I, J, K = 2, 2, 3
    R = [[1, 2], [3, 4]]
    M = [[1, 1], [0, 1]]
    U = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='f')  # 2x3
    V = numpy.array([[7, 8, 9], [10, 11, 12]], dtype='f')  # 2x3
    R_pred = numpy.array([[50, 68], [122, 167]], dtype='f')  # 2x2
    nmf = NMF(R, M, K)

    def reset_2():
        nmf.U = numpy.copy(U)
        nmf.V = numpy.copy(V)

    for k in range(0, K):
        reset_2()
        nmf.update_U(k)
        for i in range(0, I):
            new_Uik = U[i][k] * sum([V[j][k] * R[i][j] / R_pred[i, j] for j in range(0, J) if M[i][j]]) \
                / sum([V[j][k] for j in range(0, J) if M[i][j]])
            assert abs(new_Uik - nmf.U[i, k]) < 0.00001
    for k in range(0, K):
        reset_2()
        nmf.update_V(k)
        for j in range(0, J):
            new_Vjk = V[j][k] * sum([U[i][k] * R[i][j] / R_pred[i, j] for i in range(0, I) if M[i][j]]) \
                / sum([U[i][k] for i in range(0, I) if M[i][j]])
            assert abs(new_Vjk - nmf.V[j, k]) < 0.00001
# Load in data
R = numpy.loadtxt(input_folder + "R.txt")
M = numpy.ones((I, J))

# Run the NMF algorithm, <repeats> times
times_repeats = []
performances_repeats = []
for i in range(0, repeats):
    # Set all the seeds
    numpy.random.seed(0)
    random.seed(0)
    scipy.random.seed(0)

    # Run the classifier
    nmf = NMF(R, M, K)
    nmf.initialise(init_UV, expo_prior)
    nmf.run(iterations)

    # Extract the performances and timestamps across all iterations
    times_repeats.append(nmf.all_times)
    performances_repeats.append(nmf.all_performances)

# Check whether the seeding worked: all performances should be identical across repeats
assert all([numpy.array_equal(performances, performances_repeats[0])
            for performances in performances_repeats]), \
    "Seed went wrong - performances not the same across repeats!"

# Print out the performances, and the average times
all_times_average = list(numpy.average(times_repeats, axis=0))
all_performances = performances_repeats[0]
print "np_all_times_average = %s" % all_times_average
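# A short follow-up sketch (not in the original fragment; matplotlib is an
# assumption, as is one timestamp per recorded performance): plot test MSE
# against the averaged wall-clock timestamps, the natural use of all_times_average.
import matplotlib.pyplot as plt

plt.plot(all_times_average, all_performances['MSE'])
plt.xlabel("Time (s)")
plt.ylabel("MSE")
plt.show()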