Beispiel #1
0
def test_predict():
    """Check NMF.predict() on a held-out mask against hand-computed metrics.

    U (5x2) and V (3x2) are fixed so that U.dot(V.T) has row 0 equal to 21084
    and rows 1-4 equal to 21168 in every column. M_test selects four entries
    of R (values 3, 5, 10, 11) whose predictions are 21084, 21168, 21168 and
    21168; MSE, R^2 and Rp below are worked out by hand from those numbers.
    """
    # K matches the number of columns of U and V below (2).
    (I, J, K) = (5, 3, 2)
    R = numpy.array(
        [[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12], [13, 14, 15]],
        dtype=float)
    M = numpy.ones((I, J))

    U = numpy.array([[125., 126.], [126., 126.], [126., 126.], [126., 126.],
                     [126., 126.]])
    V = numpy.array([[84., 84.], [84., 84.], [84., 84.]])

    # Held-out cells: (0,2)->3, (1,1)->5, (3,0)->10, (3,1)->11.
    M_test = numpy.array([[0, 0, 1], [0, 1, 0], [0, 0, 0], [1, 1, 0],
                          [0, 0, 0]])
    # Squared errors: 21081^2, 21163^2, 21158^2, 21157^2.
    MSE = (444408561. + 447872569. + 447660964. + 447618649) / 4.
    # Real mean is 7.25, so the deviations are -4.25, -2.25, 2.75, 3.75.
    R2 = 1. - (444408561. + 447872569. + 447660964. + 447618649) / (
        4.25**2 + 2.25**2 + 2.75**2 + 3.75**2)
    # Predicted mean is 21147, so the predicted deviations are -63, 21, 21, 21.
    # Cross-product sum = -4.25*-63 + -2.25*21 + 2.75*21 + 3.75*21 = 357.
    # Sums of squares: 44.75 (real) and 5292 (predicted).
    Rp = 357. / (math.sqrt(44.75) * math.sqrt(5292.))

    nmf = NMF(R, M, K)
    nmf.U = U
    nmf.V = V
    performances = nmf.predict(M_test)

    # Use a tight relative tolerance instead of ==, so the test does not
    # depend on the exact floating-point evaluation order inside NMF.
    assert numpy.isclose(performances['MSE'], MSE, rtol=1e-9, atol=0)
    assert numpy.isclose(performances['R^2'], R2, rtol=1e-9, atol=0)
    assert numpy.isclose(performances['Rp'], Rp, rtol=1e-9, atol=0)
Beispiel #2
0
def test_run():
    """Exercise NMF.run().

    run() must refuse to start before U and V are set, and one iteration
    from explicit starting factors must produce the expected U[0][0].
    """
    # R was generated from W = [[1,2],[3,4]], H = [[4,3],[2,1]].
    R = [[8, 5], [20, 13]]
    M = [[1, 1], [1, 0]]
    K = 2

    # Explicit 2x2 float32 starting factors.
    U_start = numpy.array([[10, 9], [8, 7]], dtype='f')
    V_start = numpy.array([[6, 4], [5, 3]], dtype='f')

    model = NMF(R, M, K)

    # Without U and V assigned, run() must raise with exactly this message.
    expected_message = "U and V have not been initialised - please run NMF.initialise() first."
    with pytest.raises(AssertionError) as raised:
        model.run(0)
    assert str(raised.value) == expected_message

    # One iteration from the explicit starting point; only U[0][0] is checked.
    model.U, model.V = U_start, V_start
    model.run(1)

    # Row 0 of U.dot(V.T) is [96, 77], and both cells of row 0 are observed.
    numerator = 6 * 8 / 96.0 + 5 * 5 / 77.0
    expected_U00 = 10 * numerator / (5.0 + 6.0)  # ~0.74970484061
    assert abs(expected_U00 - model.U[0][0]) < 0.000001
Beispiel #3
0
def test_initialisation():
    """Check NMF.initialise() for the 'ones' and 'random' strategies."""
    I, J = 2, 3
    K = 4
    R = numpy.ones((I, J))
    M = numpy.ones((I, J))

    # 'ones': U (I x K) and V (J x K) must be all ones.
    model = NMF(R, M, K)
    model.initialise('ones')
    assert numpy.array_equal(numpy.ones((I, K)), model.U)
    assert numpy.array_equal(numpy.ones((J, K)), model.V)

    # 'random': every entry must lie strictly inside the open interval (0, 1).
    model = NMF(R, M, K)
    model.initialise('random')
    for i, k in itertools.product(range(I), range(K)):
        assert 0 < model.U[i, k] < 1
    for j, k in itertools.product(range(J), range(K)):
        assert 0 < model.V[j, k] < 1
Beispiel #4
0
def test_compute_statistics():
    """Check compute_MSE, compute_R2 and compute_Rp on a 2x2 example.

    Each method is called as (mask, real, predicted); only the cells where
    the mask is 1 (here the second row) enter the expected values.
    """
    R = numpy.array([[1, 2], [3, 4]], dtype=float)
    M = numpy.array([[1, 1], [0, 1]])
    K = 3

    nmf = NMF(R, M, K)

    R_pred = numpy.array([[500, 550], [1220, 1342]], dtype=float)
    # Evaluate only the second row: real 3, 4 vs. predicted 1220, 1342.
    M_pred = numpy.array([[0, 0], [1, 1]])

    # Errors are 1217 and 1338.
    MSE_pred = (1217**2 + 1338**2) / 2.0
    # Real mean is 3.5, so the deviations are -0.5 and 0.5.
    R2_pred = 1. - (1217**2 + 1338**2) / (0.5**2 + 0.5**2)
    # Predicted mean is 1281, so the predicted deviations are -61 and 61.
    # Cross-product sum = 61; sums of squares 0.5 (real) and 7442 (predicted).
    Rp_pred = 61. / (math.sqrt(.5) * math.sqrt(7442.))

    # Use a tight relative tolerance instead of ==: these values pass through
    # division and sqrt, so exact equality depends on evaluation order.
    assert numpy.isclose(nmf.compute_MSE(M_pred, R, R_pred), MSE_pred,
                         rtol=1e-9, atol=0)
    assert numpy.isclose(nmf.compute_R2(M_pred, R, R_pred), R2_pred,
                         rtol=1e-9, atol=0)
    assert numpy.isclose(nmf.compute_Rp(M_pred, R, R_pred), Rp_pred,
                         rtol=1e-9, atol=0)
Beispiel #5
0
def test_init():
    """Check the NMF constructor's validation messages and stored attributes."""
    K = 0
    M = numpy.ones((2, 3))
    R_ok = numpy.ones((2, 3))

    # Each case is (R, M, exact AssertionError message). The first three
    # cover an R of the wrong rank or size. The last two cover a mask with
    # a fully unobserved row or column.
    invalid_inputs = [
        (numpy.ones(3), M,
         "Input matrix R is not a two-dimensional array, but instead 1-dimensional."),
        (numpy.ones((4, 3, 2)), M,
         "Input matrix R is not a two-dimensional array, but instead 3-dimensional."),
        (numpy.ones((3, 2)), M,
         "Input matrix R is not of the same size as the indicator matrix M: (3, 2) and (2, 3) respectively."),
        (R_ok, [[1, 1, 1], [0, 0, 0]], "Fully unobserved row in R, row 1."),
        (R_ok, [[1, 1, 0], [1, 0, 0]], "Fully unobserved column in R, column 2."),
    ]
    for R_bad, M_bad, message in invalid_inputs:
        with pytest.raises(AssertionError) as raised:
            NMF(R_bad, M_bad, K)
        assert str(raised.value) == message

    # A valid construction keeps R and M, records the sizes, and builds a
    # copy of R with 1 in every unobserved cell.
    R = [[1, 2, 0, 4], [5, 0, 7, 0]]
    M = [[1, 1, 0, 1], [1, 0, 1, 0]]
    K = 2
    expected_rows, expected_cols = 2, 4
    expected_R_excl_unknown = [[1, 2, 1, 4], [5, 1, 7, 1]]

    model = NMF(R, M, K)
    assert numpy.array_equal(R, model.R)
    assert numpy.array_equal(M, model.M)
    assert model.I == expected_rows
    assert model.J == expected_cols
    assert model.K == K
    assert numpy.array_equal(expected_R_excl_unknown, model.R_excl_unknown)
Beispiel #6
0
def test_compute_I_div():
    """Check NMF.compute_I_div() against a hand-expanded I-divergence.

    Only the observed cells (M == 1) contribute, each adding the term
    R * log(R / R_pred) - R + R_pred.
    """
    R = [[1, 2, 0, 4], [5, 0, 7, 0]]
    M = [[1, 1, 0, 1], [1, 0, 1, 0]]
    K = 2

    U = numpy.array([[1, 2], [3, 4]], dtype='f')  # 2x2
    V = numpy.array([[5, 7, 9, 11], [6, 8, 10, 12]], dtype='f').T  # 4x2
    # U.dot(V.T) = [[17, 23, 29, 35], [39, 53, 67, 81]]

    # (real value, predicted value) for each observed cell, in row-major order.
    observed = [(1.0, 17.0), (2.0, 23.0), (4.0, 35.0), (5.0, 39.0),
                (7.0, 67.0)]
    expected_I_div = sum(r * math.log(r / p) - r + p for r, p in observed)

    nmf = NMF(R, M, K)
    nmf.U = U
    nmf.V = V

    I_div = nmf.compute_I_div()
    # A sum of logarithms is sensitive to evaluation order, so compare
    # with a tight relative tolerance rather than exact ==.
    assert numpy.isclose(I_div, expected_I_div, rtol=1e-9, atol=0)
Beispiel #7
0
        check_empty_rows_columns(M, fraction)

# We now run the NMF algorithm on each of the M's for each fraction.
# print(...) with a single argument behaves the same under Python 2 and 3.
all_performances = {metric: [] for metric in metrics}
average_performances = {metric: []
                        for metric in metrics}  # averaged over repeats
for (fraction, Ms, Ms_test) in zip(fractions_unknown, all_Ms, all_Ms_test):
    print("Trying fraction %s." % fraction)

    # Run the algorithm up to <repeats> times and store all the performances
    for metric in metrics:
        all_performances[metric].append([])
    for (repeat, M, M_test) in zip(range(0, repeats), Ms, Ms_test):
        print("Repeat %s of fraction %s." % (repeat + 1, fraction))

        nmf = NMF(R, M, K)
        nmf.initialise(init_UV, expo_prior)
        nmf.run(iterations)

        # Measure the performances on the entries selected by M_test
        performances = nmf.predict(M_test)
        for metric in metrics:
            # Add this metric's performance to this fraction's list
            all_performances[metric][-1].append(performances[metric])

    # Average over the runs actually performed. zip() stops at the shortest
    # of range(repeats), Ms and Ms_test, so this can be fewer than <repeats>.
    # NaN marks a fraction for which no run happened.
    for metric in metrics:
        scores = all_performances[metric][-1]
        if scores:
            average_performances[metric].append(
                sum(scores) / float(len(scores)))
        else:
            average_performances[metric].append(float('nan'))

Beispiel #8
0
from BNMTF.drug_sensitivity.load_data import load_Sanger

import matplotlib.pyplot as plt

##########

# Fit NMF to the Sanger data returned by load_Sanger and print the MSE
# values recorded across iterations.
standardised = False  # standardised Sanger or unstandardised
no_folds = 5  # NOTE(review): not referenced in the code shown here

iterations = 10000
# NOTE(review): I and J are not referenced in the code shown here; only K is
# passed to NMF.
I, J, K = 622, 139, 10

# Initialisation strategy and prior, both passed to nmf.initialise().
init_UV = 'exponential'
expo_prior = 1 / 10.

# Load in data; only the second and third returned values are used.
(_, X_min, M, _, _, _, _) = load_Sanger(standardised=standardised)

# Run the algorithm
nmf = NMF(X_min, M, K)
nmf.initialise(init_UV, expo_prior)
nmf.run(iterations)

# Print the performances across iterations (MSE). The print(...) form
# works under both Python 2 and Python 3.
print("all_performances = %s" % nmf.all_performances['MSE'])
'''
all_performances = [174.63390857784989, 46.259552868280458, 34.97217434170355, 29.033239356587966, 24.840900151718895, 21.687298364331998, 19.235352907182676, 17.279545376055029, 15.686231064488448, 14.36522612054398, 13.253714545011841, 12.306693098578767, 11.491112174850791, 10.782169804720743, 10.160897447825528, 9.6125441849530517, 9.1254685730168017, 8.6903618584485045, 8.2996926785838649, 7.9473030293035904, 7.6281095696852965, 7.337879585172586, 7.0730607216384342, 6.8306500183440901, 6.6080920512923598, 6.4031989080119267, 6.2140867226306655, 6.0391249060348837, 5.8768952037977549, 5.7261584317561409, 5.5858272606520671, 5.4549438046397754, 5.3326610532021634, 5.2182273995036121, 5.1109736797361194, 5.0103022612418542, 4.9156778119761171, 4.8266194573130417, 4.7426940875109009, 4.6635106241813284, 4.5887150897178559, 4.5179863519609711, 4.4510324390461768, 4.3875873376159067, 4.3274082023268345, 4.2702729165708533, 4.2159779541206426, 4.1643364994447625, 4.1151767910560908, 4.0683406577326, 4.02368222200431, 3.9810667490880114, 3.9403696226323017, 3.9014754312975652, 3.8642771524448016, 3.8286754211045557, 3.7945778740079037, 3.7618985598308976, 3.730557407970025, 3.7004797491652344, 3.6715958821420052, 3.6438406811796864, 3.6171532401457136, 3.5914765490835876, 3.5667571999134742, 3.5429451182165481, 3.5199933184288219, 3.4978576800821366, 3.4764967429989428, 3.4558715195858607, 3.435945322574951, 3.4166836067478124, 3.3980538233314412, 3.3800252858991295, 3.362569046731545, 3.3456577827016551, 3.329265689844453, 3.3133683858564371, 3.2979428198462681, 3.2829671887237022, 3.2684208596751332, 3.2542842982244746, 3.2405390014273614, 3.2271674357876363, 3.2141529795227877, 3.201479868838224, 3.1891331479006642, 3.1770986222282169, 3.1653628152375308, 3.1539129277120681, 3.1427367999733602, 3.1318228765561682, 3.1211601732035024, 3.1107382460125659, 3.1005471625751193, 3.090577474968812, 3.0808201944656086, 3.0712667678334657, 3.0619090551188579, 
3.0527393088009953, 3.0437501542226517, 3.0349345712022902, 3.0262858767471021, 3.0177977087825658, 3.0094640108296145, 3.0012790175575108, 2.9932372411508532, 2.9853334584296864, 2.9775626986691277, 2.9699202320654252, 2.9624015588022763, 2.9550023986718745, 2.9477186812104761, 2.9405465363094172, 2.9334822852674036, 2.9265224322504406, 2.9196636561300409, 2.9129028026720594, 2.9062368770503682, 2.8996630366629157, 2.8931785842281399, 2.886780961143379, 2.8804677410876285, 2.8742366238526005, 2.8680854293887972, 2.8620120920533489, 2.8560146550492824, 2.850091265045879, 2.8442401669719448, 2.8384596989742996, 2.8327482875359071, 2.8271044427470562, 2.8215267537263164, 2.8160138841868094, 2.8105645681447866, 2.805177605768375, 2.7998518593643236, 2.7945862495013234, 2.7893797512685388, 2.7842313906687362, 2.7791402411456314, 2.7741054202437558, 2.7691260864022533, 2.7642014358801075, 2.7593306998144453, 2.7545131414096558, 2.7497480532577394, 2.7450347547891703, 2.7403725898519999, 2.7357609244204415, 2.7311991444288415, 2.7266866537318237, 2.7222228721870358, 2.7178072338593888, 2.7134391853441659, 2.709118184206365, 2.7048436975336054, 2.7006152005992208, 2.6964321756332765, 2.6922941106966096, 2.6882004986557861, 2.6841508362540587, 2.6801446232755173, 2.6761813617976347, 2.6722605555283288, 2.6683817092243296, 2.6645443281849386, 2.6607479178191058, 2.6569919832798488, 2.6532760291633113, 2.6495995592674366, 2.6459620764062537, 2.6423630822771131, 2.6388020773746974, 2.6352785609506193, 2.6317920310128575, 2.6283419843626219, 2.6249279166648463, 2.621549322549245, 2.618205695738542, 2.6148965292010784, 2.6116213153245318, 2.6083795461085661, 2.6051707133735196, 2.6019943089827167, 2.5988498250756527, 2.5957367543111771, 2.5926545901169655, 2.5896028269448124, 2.5865809605285546, 2.5835884881448035, 2.580624908873232, 2.57768972385557, 2.5747824365529395, 2.5719025529985751, 2.569049582046965, 2.5662230356162157, 2.5634224289244245, 2.5606472807190546, 
2.5578971134975585, 2.5551714537200407, 2.5524698320124379, 2.5497917833604711, 2.5471368472934475, 2.544504568057782, 2.5418944947806681, 2.5393061816225924, 2.5367391879193555, 2.5341930783135322, 2.5316674228754019, 2.52916179721274, 2.5266757825704387, 2.5242089659200238, 2.5217609400383503, 2.5193313035769638, 2.5169196611212858, 2.514525623240873, 2.5121488065301429, 2.5097888336401839, 2.5074453333029791, 2.505117940346179, 2.5028062957016481, 2.5005100464052621, 2.4982288455908122, 2.4959623524767256, 2.4937102323465732, 2.49147215652446, 2.4892478023439391, 2.4870368531129383, 2.484838998073152, 2.482653932355749, 2.4804813569329429, 2.4783209785659044, 2.4761725097495439, 2.4740356686543077, 2.4719101790654938, 2.4697957703201019, 2.4676921772418834, 2.4655991400743162, 2.463516404412716, 2.461443721134271, 2.4593808463278477, 2.4573275412227082, 2.4552835721163722, 2.4532487103025895, 2.4512227319989695, 2.4492054182738583, 2.4471965549745991, 2.445195932654773, 2.4432033465022975, 2.4412185962683939, 2.4392414861959999, 2.4372718249499901, 2.4353094255473025, 2.4333541052881769, 2.4314056856880875, 2.4294639924107204, 2.4275288552017091, 2.4256001078231426, 2.4236775879895074, 2.4217611373043857, 2.4198506011977456, 2.4179458288651734, 2.4160466732070218, 2.414152990769491, 2.4122646416860576, 2.4103814896203883, 2.4085034017096802, 2.4066302485095665, 2.4047619039394936, 2.4028982452292533, 2.4010391528661987, 2.3991845105437495, 2.397334205110119, 2.3954881265184409, 2.3936461677775234, 2.3918082249029444, 2.3899741968697783, 2.3881439855649775, 2.3863174957413928, 2.3844946349716771, 2.3826753136034426, 2.3808594447147478, 2.3790469440701285, 2.377237730077213, 2.3754317237439708, 2.373628848636669, 2.3718290308377443, 2.3700321989048074, 2.368238283829637, 2.3664472189980086, 2.3646589401496922, 2.3628733853389834, 2.3610904948956395, 2.3593102113860707, 2.3575324795753834, 2.3557572463891554, 2.3539844608759863, 2.3522140741705111, 
2.3504460394564033, 2.3486803119302353, 2.3469168487649905, 2.3451556090746055, 2.3433965538786712, 2.3416396460671627, 2.3398848503660927, 2.3381321333027172, 2.3363814631720552, 2.3346328100029767, 2.3328861455250727, 2.3311414431354445, 2.3293986778664304, 2.3276578263533549, 2.3259188668025796, 2.324181778959967, 2.3224465440797259, 2.3207131448939218, 2.3189815655816077, 2.3172517917391526, 2.3155238103502502, 2.3137976097569015, 2.3120731796302492, 2.3103505109422042, 2.3086295959369738, 2.306910428103401, 2.305193002147528, 2.3034773139653741, 2.3017633606162931, 2.300051140296552, 2.2983406523133811, 2.2966318970592723, 2.2949248759867413, 2.2932195915834184, 2.2915160473471987, 2.2898142477624681, 2.2881141982758288, 2.2864159052726443, 2.2847193760538107, 2.2830246188128305, 2.281331642613277, 2.2796404573663258, 2.2779510738090027, 2.2762635034823213, 2.2745777587099729, 2.2728938525771007, 2.2712117989095622, 2.2695316122533806, 2.2678533078543146, 2.2661769016377811, 2.2645024101891589, 2.2628298507340556, 2.2611592411190822, 2.259490599792537, 2.2578239457856624, 2.2561592986937673, 2.2544966786577767, 2.2528361063459323, 2.2511776029354551, 2.2495211900948746, 2.2478668899659651, 2.2462147251461326, 2.2445647186711017, 2.2429168939975148, 2.2412712749856269, 2.2396278858825078, 2.2379867513049763, 2.2363478962229837, 2.2347113459429049, 2.2330771260910773, 2.2314452625975547, 2.2298157816797244, 2.2281887098263566, 2.2265640737815282, 2.2249419005287097, 2.2233222172753417, 2.2217050514368633, 2.2200904306214073, 2.2184783826144274, 2.2168689353634545, 2.2152621169630247, 2.2136579556395661, 2.2120564797366655, 2.2104577177003084, 2.2088616980642684, 2.2072684494356514, 2.2056780004806522, 2.2040903799101739, 2.202505616466107, 2.2009237389071239, 2.199344775995284, 2.1977687564820432, 2.1961957090952802, 2.1946256625258069, 2.1930586454142431, 2.1914946863383484, 2.1899338138000188, 2.1883760562129093, 2.1868214418900873, 2.1852699990318447, 
2.1837217557135329, 2.1821767398741549, 2.1806349793044864, 2.179096501635851, 2.1775613343288658, 2.1760295046627784, 2.1745010397243734, 2.1729759663977348, 2.1714543113540192, 2.1699361010410572, 2.1684213616741168, 2.1669101192259492, 2.1654023994174412, 2.1638982277088172, 2.1623976292906657, 2.1609006290751185, 2.1594072516880272, 2.1579175214603921, 2.1564314624208993, 2.1549490982880206, 2.1534704524628743, 2.1519955480223132, 2.1505244077117696, 2.1490570539390244, 2.1475935087677867, 2.1461337939115923, 2.1446779307283097, 2.1432259402143403, 2.1417778429994674, 2.1403336593420357, 2.1388934091239569, 2.137457111846353, 2.1360247866255184, 2.1345964521885969, 2.1331721268700528, 2.1317518286083068, 2.1303355749423205, 2.1289233830087051, 2.1275152695391162, 2.1261112508574058, 2.1247113428777462, 2.1233155611024768, 2.1219239206202172, 2.1205364361042842, 2.1191531218115283, 2.1177739915812031, 2.1163990588337738, 2.1150283365704858, 2.1136618373729066, 2.1122995734024199, 2.1109415564002916, 2.1095877976878836, 2.1082383081666776, 2.1068930983189791, 2.1055521782086069, 2.104215557481715, 2.1028832453676336, 2.1015552506804593, 2.1002315818200525, 2.0989122467737964, 2.0975972531181042, 2.0962866080204932, 2.0949803182412849, 2.0936783901359437, 2.0923808296571869, 2.0910876423575404, 2.0897988333916859, 2.0885144075192619, 2.0872343691075566, 2.0859587221342784, 2.0846874701908358, 2.0834206164850664, 2.0821581638447593, 2.0809001147207051, 2.0796464711902565, 2.0783972349606779, 2.0771524073727368, 2.0759119894044211, 2.0746759816744729, 2.0734443844464159, 2.0722171976321202, 2.0709944207959521, 2.0697760531586376, 2.0685620936012614, 2.0673525406695399, 2.0661473925776663, 2.0649466472126665, 2.0637503021386232, 2.0625583546007795, 2.0613708015301158, 2.0601876395472734, 2.0590088649672227, 2.0578344738035055, 2.0566644617727943, 2.0554988242988759, 2.0543375565177948, 2.0531806532814372, 2.0520281091627344, 2.0508799184597755, 2.0497360752003284, 
2.0485965731462819, 2.0474614057981926, 2.046330566399877, 2.045204047942573, 2.0440818431695904, 2.0429639445808569, 2.0418503444373362, 2.040741034765297, 2.0396360073608815, 2.0385352537946, 2.0374387654155885, 2.0363465333559936, 2.0352585485353858, 2.0341748016652366, 2.033095283252933, 2.0320199836063053, 2.0309488928378516, 2.0298820008689127, 2.0288192974341004, 2.0277607720851196, 2.0267064141953126, 2.0256562129634919, 2.0246101574183655, 2.0235682364221601, 2.0225304386751324, 2.0214967527192202, 2.0204671669422543, 2.0194416695816684, 2.018420248728555, 2.0174028923315053, 2.016389588200393, 2.0153803240101853, 2.01437508730467, 2.0133738655002662, 2.012376645889451, 2.011383415644767, 2.0103941618220298, 2.0094088713642666, 2.0084275311048456, 2.0074501277712633, 2.0064766479883756, 2.0055070782819655, 2.0045414050819059, 2.0035796147256875, 2.0026216934615038, 2.001667627451567, 2.0007174027753876, 1.9997710054326756, 1.998828421346754, 1.9978896363673677, 1.9969546362739561, 1.9960234067783678, 1.9950959335279939, 1.9941722021087012, 1.993252198047502, 1.9923359068155613, 1.9914233138308841, 1.990514404461128, 1.989609164026209, 1.9887075778010621, 1.9878096310182161, 1.9869153088703286, 1.9860245965127779, 1.9851374790662866, 1.9842539416190748, 1.9833739692296715, 1.9824975469289057, 1.9816246597225708, 1.9807552925936143, 1.979889430504334, 1.9790270583986662, 1.9781681612045641, 1.9773127238357491, 1.9764607311942712, 1.9756121681722147, 1.9747670196540601, 1.9739252705184474, 1.9730869056403169, 1.9722519098927909, 1.9714202681490007, 1.9705919652841311, 1.9697669861770521, 1.968945315712201, 1.9681269387814362, 1.9673118402856553, 1.9665000051365129, 1.9656914182581209, 1.964886064588597, 1.964083929081718, 1.9632849967085659, 1.9624892524588597, 1.9616966813426957, 1.9609072683917925, 1.960120998661065, 1.9593378572299431, 1.9585578292039005, 1.9577808997156307, 1.9570070539264508, 1.9562362770275357, 1.9554685542412162, 1.9547038708222464, 
1.9539422120588898, 1.9531835632741696, 1.9524279098270167, 1.9516752371133697, 1.9509255305671975, 1.9501787756617268, 1.9494349579102401, 1.9486940628672755, 1.9479560761295476, 1.9472209833368828, 1.9464887701731741, 1.9457594223671832, 1.9450329256937011, 1.9443092659740899, 1.9435884290771843, 1.9428704009203104, 1.9421551674697071, 1.9414427147416615, 1.9407330288030078, 1.9400260957718818, 1.9393219018184857, 1.9386204331656631, 1.9379216760896014, 1.9372256169204329, 1.9365322420429003, 1.9358415378968645, 1.935153490977884, 1.9344680878377467, 1.9337853150851074, 1.9331051593858115, 1.9324276074634845, 1.9317526461000172, 1.9310802621359209, 1.9304104424708295, 1.9297431740638826, 1.929078443934084, 1.9284162391607491, 1.9277565468837976, 1.9270993543041106, 1.9264446486838784, 1.9257924173468035, 1.9251426476785434, 1.9244953271268723, 1.923850443201921, 1.9232079834765172, 1.9225679355863801, 1.9219302872302031, 1.9212950261700024, 1.9206621402312896, 1.9200316173031433, 1.9194034453384203, 1.9187776123538651, 1.9181541064302601, 1.9175329157125731, 1.9169140284099642, 1.9162974327959756, 1.9156831172084325, 1.9150710700497522, 1.9144612797867855, 1.9138537349509432, 1.9132484241382151, 1.9126453360091096, 1.9120444592888661, 1.9114457827671554, 1.9108492952982654, 1.9102549858009799, 1.9096628432586162, 1.9090728567189146, 1.9084850152939989, 1.9078993081602007, 1.9073157245582875, 1.9067342537929244, 1.9061548852330124, 1.9055776083112987, 1.9050024125243585, 1.9044292874324571, 1.903858222659506, 1.9032892078927253, 1.9027222328826583, 1.9021572874431234, 1.9015943614506432, 1.9010334448447503, 1.9004745276274404, 1.8999175998633129, 1.8993626516790829, 1.898809673263534, 1.8982586548673894, 1.8977095868029179, 1.897162459443839, 1.8966172632250913, 1.8960739886426035, 1.8955326262529819, 1.8949931666734807, 1.8944556005814546, 1.8939199187144429, 1.8933861118696405, 1.8928541709038076, 1.8923240867330202, 1.8917958503321817, 1.8912694527350515, 
1.8907448850338409, 1.8902221383787936, 1.8897012039782053, 1.8891820730979316, 1.8886647370610863, 1.8881491872478187, 1.8876354150951535, 1.8871234120962772, 1.886613169800802, 1.8861046798139633, 1.8855979337966327, 1.8850929234648164, 1.8845896405895388, 1.8840880769962782, 1.883588224564918, 1.8830900752291979, 1.8825936209766099, 1.8820988538478194, 1.8816057659365713, 1.8811143493892837, 1.8806245964047448, 1.8801364992337355, 1.879650050178691, 1.8791652415933744, 1.8786820658826917, 1.8782005155021622, 1.8777205829576729, 1.877242260805176, 1.8767655416502762, 1.8762904181479383, 1.8758168830021711, 1.8753449289656616, 1.8748745488393823, 1.8744057354724435, 1.8739384817614817, 1.8734727806505813, 1.8730086251307474, 1.8725460082397289, 1.872084923061462, 1.8716253627259836, 1.8711673204089587, 1.8707107893312869, 1.8702557627589378, 1.8698022340024392, 1.8693501964166153, 1.8688996434003062, 1.8684505683959498, 1.8680029648892835, 1.867556826408971, 1.8671121465263365, 1.8666689188549896, 1.8662271370504553, 1.8657867948099862, 1.8653478858720189, 1.8649104040160267, 1.8644743430621296, 1.8640396968707298, 1.8636064593422417, 1.8631746244166894, 1.8627441860735017, 1.8623151383311267, 1.8618874752466501, 1.8614611909156007, 1.8610362794714947, 1.860612735085666, 1.8601905519667545, 1.8597697243606099, 1.8593502465498455, 1.8589321128535248, 1.8585153176269711, 1.8580998552612578, 1.8576857201831103, 1.8572729068543663, 1.8568614097719376, 1.856451223467356, 1.8560423425064176, 1.8556347614889928, 1.8552284750487305, 1.8548234778526402, 1.854419764600991, 1.8540173300267544, 1.8536161688955843, 1.853216276005349, 1.8528176461859192, 1.852420274298811, 1.8520241552369994, 1.8516292839245794, 1.8512356553164429, 1.8508432643980999, 1.8504521061853385, 1.8500621757238678, 1.8496734680891778, 1.8492859783862834, 1.8488997017492905, 1.8485146333412394, 1.8481307683538246, 1.8477481020071223, 1.8473666295492777, 1.8469863462564375, 1.8466072474321129, 
1.8462293284072397, 1.8458525845399407, 1.8454770112149905, 1.845102603843777, 1.8447293578640376, 1.8443572687394638, 1.8439863319596981, 1.8436165430397768, 1.8432478975201718, 1.8428803909663498, 1.842514018968711, 1.8421487771420828, 1.8417846611257433, 1.8414216665830532, 1.8410597892012592, 1.8406990246911821, 1.8403393687871823, 1.8399808172465872, 1.8396233658499248, 1.8392670104002729, 1.8389117467231899, 1.8385575706665707, 1.8382044781003677, 1.8378524649163528, 1.8375015270278503, 1.8371516603696221, 1.8368028608976426, 1.8364551245887462, 1.8361084474406282, 1.8357628254714258, 1.8354182547196811, 1.8350747312440248, 1.8347322511229969, 1.8343908104549287, 1.8340504053575239, 1.8337110319679313, 1.8333726864423303, 1.8330353649558917, 1.8326990637024008, 1.832363778894212, 1.8320295067621148, 1.8316962435548862, 1.8313639855393622, 1.8310327290000987, 1.8307024702391979, 1.8303732055762578, 1.8300449313479825, 1.8297176439081835, 1.8293913396275057, 1.8290660148932236, 1.8287416661091533, 1.8284182896954773, 1.8280958820884241, 1.8277744397402829, 1.8274539591190755, 1.82713443670857, 1.8268158690079213, 1.8264982525316038, 1.8261815838092599, 1.8258658593855501, 1.8255510758198417, 1.8252372296863004, 1.8249243175734895, 1.8246123360843918, 1.8243012818361624, 1.8239911514599967, 1.8236819416010037, 1.823373648918009, 1.8230662700834388, 1.8227598017831774, 1.8224542407164293, 1.8221495835955714, 1.8218458271458891, 1.8215429681057125, 1.8212410032260149, 1.8209399292704336, 1.820639743014959, 1.8203404412480637, 1.8200420207703909, 1.8197444783945314, 1.819447810945235, 1.8191520152589189, 1.818857088183732, 1.8185630265793857, 1.8182698273170566, 1.8179774872792878, 1.8176860033597422, 1.8173953724631675, 1.8171055915054284, 1.816816657413084, 1.8165285671234588, 1.816241317584671, 1.8159549057551836, 1.8156693286039685, 1.8153845831102025, 1.8151006662634401, 1.8148175750631796, 1.814535306518948, 1.8142538576502951, 1.8139732254863599, 
1.813693407066147, 1.8134143994382228, 1.813136199660599, 1.8128588048007903, 1.812582211935629, 1.8123064181510895, 1.8120314205424004, 1.8117572162138305, 1.8114838022785873, 1.8112111758587972, 1.8109393340853588, 1.810668274097984, 1.8103979930449565, 1.8101284880831325, 1.8098597563779542, 1.8095917951031251, 1.8093246014408277, 1.809058172581564, 1.8087925057238849, 1.8085275980745266, 1.8082634468483798, 1.808000049268268, 1.8077374025649386, 1.8074755039770705, 1.8072143507511305, 1.8069539401413139, 1.8066942694095627, 1.80643533582538, 1.8061771366659143, 1.8059196692158666, 1.8056629307673142, 1.8054069186198363, 1.8051516300803105, 1.8048970624630221, 1.8046432130894519, 1.8043900792883369, 1.8041376583955941, 1.8038859477543263, 1.8036349447145883, 1.8033846466336578, 1.8031350508757393, 1.8028861548119666, 1.8026379558204544, 1.8023904512862872, 1.8021436386012937, 1.8018975151641392, 1.801652078380408, 1.8014073256623091, 1.8011632544288367, 1.8009198621057316, 1.8006771461252682, 1.8004351039266051, 1.8001937329552888, 1.7999530306635805, 1.7997129945102603, 1.7994736219607628, 1.7992349104869316, 1.7989968575671196, 1.79875946068624, 1.7985227173356417, 1.7982866250131322, 1.7980511812228883, 1.7978163834755807, 1.7975822292883066, 1.7973487161843931, 1.7971158416937509, 1.7968836033525133, 1.7966519987031748, 1.7964210252946149, 1.7961906806820453, 1.795960962426943, 1.7957318680971437, 1.7955033952667907, 1.7952755415162802, 1.7950483044323655, 1.794821681608082, 1.7945956706427009, 1.7943702691417718, 1.7941454747171974, 1.7939212849871187, 1.7936976975758876, 1.7934747101142101, 1.7932523202390955, 1.7930305255936478, 1.7928093238275029, 1.7925887125963367, 1.7923686895622544, 1.7921492523935498, 1.7919303987649393, 1.7917121263572298, 1.7914944328576776, 1.7912773159597055, 1.7910607733632051, 1.7908448027741986, 1.790629401905157, 1.7904145684748474, 1.7902003002082472, 1.7899865948368112, 1.7897734500982498, 1.7895608637366791]
'''

# Plot the performances (MSE): the list stored in
# nmf.all_performances['MSE'], the same list printed earlier in this script.
# NOTE(review): no plt.show()/savefig() is visible here, so when this runs as
# a plain (non-interactive) script the figure may never be displayed.
# Confirm the intended matplotlib backend or usage.
plt.plot(nmf.all_performances['MSE'])
Beispiel #9
0
def test_update():
    """Check NMF.update_U(k) and NMF.update_V(k) one column at a time.

    Before every single-column update, U and V are reset to fresh copies of
    the starting factors, so each result is compared against values derived
    from the same starting point.
    """
    # --- Case 1: I = 2, J = 4, K = 2 ---
    I, J = 2, 4
    R = [[1, 2, 0, 4], [5, 0, 7, 0]]
    M = [[1, 1, 0, 1], [1, 0, 1, 0]]
    K = 2

    U = numpy.array([[1, 2], [3, 4]], dtype='f')  # 2x2
    V = numpy.array([[5, 6], [7, 8], [9, 10], [11, 12]], dtype='f')  # 4x2
    # U.dot(V.T) = [[17, 23, 29, 35], [39, 53, 67, 81]]

    # Expected U[i][k] after update_U(k):
    # U[i][k] * sum_j(V[j][k] * R[i][j] / R_pred[i][j]) / sum_j V[j][k],
    # with j running over the observed cells of row i.
    new_U = [
        [1 * (1 * 5 / 17.0 + 2 * 7 / 23.0 + 4 * 11 / 35.0) / float(5 + 7 + 11),
         2 * (1 * 6 / 17.0 + 2 * 8 / 23.0 + 4 * 12 / 35.0) / float(6 + 8 + 12)],
        [3 * (5 * 5 / 39.0 + 7 * 9 / 67.0) / float(5 + 9),
         4 * (5 * 6 / 39.0 + 7 * 10 / 67.0) / float(6 + 10)],
    ]
    # Expected V[j][k] after update_V(k): the same formula with the roles of
    # U and V swapped, over the observed cells of column j.
    new_V = [
        [5 * (1 * 1 / 17.0 + 3 * 5 / 39.0) / float(1 + 3),
         6 * (2 * 1 / 17.0 + 4 * 5 / 39.0) / float(2 + 4)],
        [7 * (1 * 2 / 23.0) / float(1), 8 * (2 * 2 / 23.0) / float(2)],
        [9 * (3 * 7 / 67.0) / float(3), 10 * (4 * 7 / 67.0) / float(4)],
        [11 * (1 * 4 / 35.0) / float(1), 12 * (2 * 4 / 35.0) / float(2)],
    ]

    nmf = NMF(R, M, K)

    def restore(model, U_init, V_init):
        # Assign fresh copies, in case an update modifies the arrays in place.
        model.U = numpy.copy(U_init)
        model.V = numpy.copy(V_init)

    for k in range(K):
        restore(nmf, U, V)
        nmf.update_U(k)
        for i in range(I):
            assert abs(new_U[i][k] - nmf.U[i, k]) < 0.00001

    for k in range(K):
        restore(nmf, U, V)
        nmf.update_V(k)
        for j in range(J):
            assert abs(new_V[j][k] - nmf.V[j, k]) < 0.00001

    # --- Case 2: square data (I = J = 2), K = 3 ---
    I, J, K = 2, 2, 3
    R = [[1, 2], [3, 4]]
    M = [[1, 1], [0, 1]]

    U = numpy.array([[1, 2, 3], [4, 5, 6]], dtype='f')  # 2x3
    V = numpy.array([[7, 8, 9], [10, 11, 12]], dtype='f')  # 2x3
    R_pred = numpy.array([[50, 68], [122, 167]], dtype='f')  # U.dot(V.T)

    nmf = NMF(R, M, K)

    for k in range(K):
        restore(nmf, U, V)
        nmf.update_U(k)
        for i in range(I):
            observed_cols = [j for j in range(J) if M[i][j]]
            numerator = sum([V[j][k] * R[i][j] / R_pred[i, j]
                             for j in observed_cols])
            denominator = sum([V[j][k] for j in observed_cols])
            new_Uik = U[i][k] * numerator / denominator
            assert abs(new_Uik - nmf.U[i, k]) < 0.00001

    for k in range(K):
        restore(nmf, U, V)
        nmf.update_V(k)
        for j in range(J):
            observed_rows = [i for i in range(I) if M[i][j]]
            numerator = sum([U[i][k] * R[i][j] / R_pred[i, j]
                             for i in observed_rows])
            denominator = sum([U[i][k] for i in observed_rows])
            new_Vjk = V[j][k] * numerator / denominator
            assert abs(new_Vjk - nmf.V[j, k]) < 0.00001