def compare_hypotheses(self, old: Hypothesis, new: SingleParameterHypothesis,
                       measurements: Sequence[Measurement]):
    """
    Decides whether the new hypothesis should replace the current best one.

    The new hypothesis is accepted unconditionally when ``old`` equals MAX_HYPOTHESIS.
    It is rejected when any compound term of its function has a coefficient of zero or a
    contribution (``new.calc_term_contribution``) below ``self.epsilon``.
    Otherwise the decision is made on the RSS if ``self.compare_with_RSS`` is set,
    and on the SMAPE if it is not.

    :param old: the best hypothesis found so far
    :param new: the candidate hypothesis
    :param measurements: the measurements passed to ``new.calc_term_contribution``
    :return: True if the new hypothesis should become the best hypothesis, otherwise False
    """
    if old == MAX_HYPOTHESIS:
        return True

    # Division and invalid-value warnings are silenced only while the contributions are
    # checked. The context manager restores the previous numpy error state on every exit
    # path, including the early return below.
    with numpy.errstate(divide='ignore', invalid='ignore'):
        # for all compound terms check if they are smaller than minimum allowed contribution
        for term in new.function.compound_terms:
            # ignore this hypothesis, since one of the terms contributes less than epsilon to the function
            if term.coefficient == 0 or new.calc_term_contribution(term, measurements) < self.epsilon:
                return False

    # print smapes in debug mode (lazy formatting, the message text is unchanged)
    logging.debug("next hypothesis SMAPE: %s RSS:%s", new.SMAPE, new.RSS)
    logging.debug("best hypothesis SMAPE: %s RSS:%s", old.SMAPE, old.RSS)

    if self.compare_with_RSS:
        return new.RSS < old.RSS
    return new.SMAPE < old.SMAPE
def build_hypotheses(self, measurements):
    """
    Lazily generates one candidate hypothesis per matching building block.

    The building blocks are obtained from ``self.get_matching_hypotheses(measurements)``.
    Each of them is shallow-copied, wrapped in a SingleParameterFunction and yielded as a
    SingleParameterHypothesis created with ``self.use_median``. The hypotheses are only
    constructed here; no cost is computed in this method.

    :param measurements: the measurements forwarded to ``get_matching_hypotheses``
    :return: a generator of SingleParameterHypothesis objects
    """
    hypotheses_building_blocks = self.get_matching_hypotheses(measurements)

    for compound_term in hypotheses_building_blocks:
        # copy the term so that the shared building block is not handed to the new function directly
        next_function = SingleParameterFunction(copy.copy(compound_term))
        # create single parameter hypothesis from function
        yield SingleParameterHypothesis(next_function, self.use_median)
def test_compare(self):
    """
    Checks the functions found by RefiningModeler against reference functions.

    Every case is a triple ``(generator, samples, (exponents, coeff))``:
    the first element is not used by the test (it appears to hold the parameters the
    samples were generated from), ``samples`` are the measured values at ``points`` and
    ``(exponents, coeff)`` describes the function the modeler is expected to find.
    ``exponents`` is None when a constant function is expected.
    """
    points = [4, 8, 16, 32, 64, 128]
    cases = [
        ((None, (12.279235119728051, 112.3997486813747)),
         [124.67898380110276, 124.67898380110276, 124.67898380110276,
          124.67898380110276, 124.67898380110276, 124.67898380110276],
         (None, (124.679, -1.17549))),
        (((0, Fraction(1, 1)), (392.837968713381, 683.8645895889935)),
         [1760.5671478913678, 2444.4317374803613, 3128.296327069355,
          3812.1609166583485, 4496.025506247342, 5179.890095836336],
         ((0.0, 1.0), (392.838, 683.865))),
        (((0, Fraction(2, 1)), (138.69179452369758, 112.44445041582443)),
         [588.4695961869953, 1150.6918482661176, 1937.8030011768885,
          2949.803054919308, 4186.692009493378, 5648.469864899094],
         ((0.0, 2.0), (138.692, 112.444))),
        (((Fraction(1, 4), 0), (231.8031252715932, 757.5927278025262)),
         [1303.2010356851542, 1505.917143334448, 1746.9885808766455,
          2033.672449625761, 2374.5989460987153, 2780.0311613973026],
         ((0.25, 0.0), (231.803, 757.593))),
        (((Fraction(1, 3), 0), (147.40207355905747, 740.6554582848072)),
         [1323.1193271863492, 1628.712990128672, 2013.7368787841829,
          2498.836580813641, 3110.023906698286, 3880.071684009308],
         ((0.333333, 0.0), (147.402, 740.655))),
        (((Fraction(1, 4), Fraction(1, 1)), (662.1669933486077, 136.57938776640577)),
         [1048.4718383883378, 1351.2616987705137, 1754.802095479854,
          2286.378790293861, 2979.9960635869884, 3877.9422853175015],
         ((0.25, 1.0), (662.167, 136.579))),
        (((Fraction(1, 3), Fraction(1, 1)), (535.6622118860412, 447.75148218635366)),
         [1957.1845595719178, 3222.171105004163, 5048.714352111772,
          7643.273950315424, 11281.697784358528, 16331.344702676097],
         ((0.333333, 1.0), (535.662, 447.751))),
        (((Fraction(1, 4), Fraction(2, 1)), (412.5706706079675, 996.343695251814)),
         [6048.741737048133, 15493.363821169984, 32295.568918666017,
          59655.521239685964, 101863.64986653095, 164625.65164339435],
         ((0.25, 2.0), (412.571, 996.344))),
        (((Fraction(1, 3), Fraction(2, 1)), (93.11229615417925, 20.367438006670188)),
         [222.43746622492063, 459.72618027424267, 914.2759402192239,
          1709.6769220384463, 3026.0233691146855, 5122.739616052577],
         ((0.333333, 2.0), (93.1123, 20.3674))),
        (((Fraction(1, 2), 0), (939.8019758412179, 402.94640866510485)),
         [1745.6947931714276, 2079.5065279286637, 2551.5876105016373,
          3219.2110800161095, 4163.373245162056, 5498.620184191001],
         ((0.5, 0.0), (939.802, 402.946))),
        (((Fraction(1, 2), Fraction(1, 1)), (198.49843369241415, 330.31007853365884)),
         [1519.7387478270496, 3001.272390797349, 5483.459690230956,
          9541.078290708863, 16053.382203308038, 26357.722033338472],
         ((0.5, 1.0), (198.498, 330.31))),
        (((Fraction(1, 2), Fraction(2, 1)), (364.8953574839538, 955.112891429775)),
         [8005.798488922153, 24678.100241316602, 61492.120408989555,
          135438.25582322088, 275437.4080892591, 529852.4683831728],
         ((0.5, 2.0), (364.895, 955.113))),
        (((Fraction(2, 3), 0), (210.3330694987003, 216.92681699057178)),
         [756.9543955249287, 1078.0403374609873, 1587.732499462487,
          2396.8183736036135, 3681.1621413478483, 5719.930789353846],
         ((0.666667, 0.0), (210.333, 216.927))),
        (((Fraction(3, 4), 0), (584.9013580111865, 547.3819137326248)),
         [2133.1312104080216, 3188.7032237497583, 4963.956667872185,
          7949.565182530901, 12970.740177185866, 21415.31628391976],
         ((0.75, 0.0), (584.901, 547.382))),
        (((Fraction(2, 3), Fraction(1, 1)), (953.7431095545323, 838.6830078923111)),
         [5180.440612885216, 11017.939204262264, 22254.963733492266,
          43220.718142861355, 81467.31186721638, 150062.28747711863],
         ((0.666667, 1.0), (953.743, 838.683))),
        (((Fraction(3, 4), Fraction(1, 1)), (355.50475595529707, 203.8586065472728)),
         [1508.7031806978325, 3264.666020281983, 6878.980165468027,
          14069.422473092825, 28032.266949776153, 54659.84835672009],
         ((1.0, 0.0), (48.1612, 429.067))),
        (((Fraction(4, 5), 0), (836.4136945625079, 988.9778707606993)),
         [3834.433979810851, 6056.270190754812, 9924.711720732794,
          16660.0596267337, 28386.981453862616, 48804.738258536],
         ((0.8, 0.0), (836.414, 988.978))),
        (((Fraction(2, 3), Fraction(2, 1)), (30.370684174349353, 735.0460670350777)),
         [7439.17078417381, 26492.029097437142, 74706.39628779484,
          185250.37318416082, 423416.90529637906, 914811.6843285251],
         ((1.0, 0.0), (-37722.0, 7374.79))),
        (((Fraction(3, 4), Fraction(2, 1)), (868.3557741916711, 651.1510359543278)),
         [8235.288783790882, 28745.079790529722, 84215.68837634563,
          219888.5845432864, 531287.5324653349, 1215054.5573746935],
         ((1.0, 1.0), (-602.755, 1361.36))),
        (((Fraction(1, 1), 0), (218.35982887307853, 796.5944762009765)),
         [3404.7377336769846, 6591.11563848089, 12963.871448088703,
          25709.383067304327, 51200.406305735574, 102182.45278259806],
         ((1.0, 0.0), (218.36, 796.594))),
        (((Fraction(1, 1), Fraction(1, 1)), (729.8185276288646, 193.81268721358396)),
         [2280.320025337536, 5381.323020754879, 13133.830509298237,
          31739.8484818023, 75153.8904176451, 174385.9862710001],
         ((1.0, 1.0), (729.819, 193.813))),
        (((Fraction(1, 1), Fraction(2, 1)), (640.8857481060144, 219.18401331861853)),
         [4147.829961203911, 16422.13470704655, 56751.993157672354,
          175988.09640300085, 505640.8524342031, 1375363.0172824815],
         ((1.0, 2.0), (640.886, 219.184))),
        (((Fraction(5, 4), 0), (41.41439439883205, 336.0107050284518)),
         [1942.1779790139603, 4562.217551923606, 10793.75695530929,
          25614.93894716142, 60865.84910208294, 144707.1154351916],
         ((1.0, 1.0), (302.182, 160.594))),
        (((Fraction(5, 4), Fraction(1, 1)), (334.34344019665406, 396.1666489172391)),
         [4816.457423065935, 16324.82895624065, 51043.67450160326,
          151094.0866783297, 430617.2857956476, 1194290.5953048149],
         ((1.0, 2.0), (352.638, 189.925))),
        (((Fraction(4, 3), 0), (646.7639733950962, 836.1733802023176)),
         [5956.133986838953, 14025.538056632176, 34359.162151911856,
          85596.68418849679, 214707.14930518836, 540045.1348296632],
         ((1.33333, 0.0), (646.764, 836.173))),
        (((Fraction(4, 3), Fraction(1, 1)), (961.3235324936308, 976.4308028867101)),
         [13361.221801905767, 47830.002071055715, 158430.21598980148,
          496957.254308979, 1500759.03676648, 4410090.312337113],
         ((1.0, 2.0), (-34653.4, 703.225))),
        (((Fraction(3, 2), 0), (993.5060588040174, 789.8477910359313)),
         [7312.288387091468, 18865.72139149913, 51543.76468510362,
          143971.22872036492, 405395.57506920083, 1144815.2873512912],
         ((1.5, 0.0), (993.506, 789.848))),
        (((Fraction(3, 2), Fraction(1, 1)), (306.3138276450713, 176.2989470011036)),
         [3127.096979662729, 12273.883197935773, 45438.844259927595,
          159873.90543152104, 541896.6790150353, 1787463.339791056],
         ((1.5, 1.0), (306.314, 176.299))),
        (((Fraction(3, 2), Fraction(2, 1)), (623.7521800036756, 545.819243769916)),
         [18089.967980640988, 111778.0688886881, 559542.6578003977,
          2470719.679039657, 10061164.053347098, 38731727.88533937],
         ((1.5, 2.0), (623.752, 545.819))),
        (((Fraction(5, 3), 0), (674.515344060474, 93.52470814254)),
         [1617.1853318529588, 3667.3060046217547, 10176.033429851634,
          30839.954953419994, 96443.81648202146, 304723.0940893776],
         ((1.66667, 0.0), (674.515, 93.5247))),
        (((Fraction(7, 4), 0), (192.79185213528302, 921.0172026961501)),
         [10612.912005989885, 35241.758587692566, 118082.9937972425,
          396726.5846888438, 1333968.1715455244, 4486460.534003467],
         ((1.75, 0.0), (192.792, 921.017))),
        (((Fraction(2, 1), 0), (601.3899361738712, 695.3677746959734)),
         [11727.274331309445, 45104.92751671617, 178615.54025834304,
          712657.9912248506, 2848827.7950908807, 11393507.010555001],
         ((2.0, 0.0), (601.39, 695.368))),
        (((Fraction(2, 1), Fraction(1, 1)), (95.64610607936808, 399.1728717576563)),
         [12869.17800232437, 76736.83748354937, 408848.66678591946,
          2043860.7495052798, 9810168.14242224, 45780433.96224817],
         ((2.0, 1.0), (95.6461, 399.173))),
        (((Fraction(2, 1), Fraction(2, 1)), (910.0933475245264, 649.3303470713025)),
         [42467.23556008789, 374924.3732605948, 2660567.19495158,
          16623766.97837287, 95748565.7510935, 521293702.00774235],
         ((2.0, 2.0), (910.093, 649.33))),
        (((Fraction(9, 4), 0), (991.6538373014305, 196.89404893784854)),
         [5446.857587036748, 22184.29382918959, 101801.40689347989,
          480526.35622160026, 2282055.9737017835, 10851623.32968404],
         ((2.0, 1.0), (-5485.73, 94.5984))),
        (((Fraction(7, 3), 0), (454.80058623925424, 800.545149328823)),
         [20787.379981321064, 102924.57970032863, 516870.1273219162,
          2603024.963156712, 13116586.527189683, 66101616.62275291],
         ((2.0, 2.0), (326893.0, 82.0924))),
        (((Fraction(5, 2), 0), (444.6771885244102, 788.914623707213)),
         [25689.945147155224, 143253.478519879, 808293.2518647105,
          4570326.31979187, 25851599.066826478, 146236657.2404956],
         ((2.5, 0.0), (444.677, 788.915))),
        (((Fraction(5, 2), Fraction(1, 1)), (561.6799261002398, 118.03818227888634)),
         [8116.123591948965, 64663.26005666098, 484046.0745404187,
          3419312.620222673, 23207812.621413387, 153160603.80521256],
         ((2.5, 1.0), (561.68, 118.038))),
        (((Fraction(5, 2), Fraction(2, 1)), (98.86684880610103, 982.9958000733947)),
         [125922.32925820061, 1601570.0898857696, 16105502.055251304,
          142353096.47013444, 1159589128.4318285, 8928380108.544924],
         ((2.5, 2.0), (0.0, 982.996))),
        (((Fraction(8, 3), 0), (119.32012560773262, 194.108049319692)),
         [7945.26627894892, 49810.980751448864, 315641.6975316356,
          2003561.5353809511, 12721184.440340932, 80773847.93606871],
         ((2.5, 1.0), (145065.0, 62.1811))),
        (((Fraction(11, 4), 0), (335.21389505653997, 713.4600389668701)),
         [32622.72952123845, 217538.8630750938, 1461501.3736992066,
          9829850.300849823, 66125167.21631561, 444833408.7346114],
         ((3.0, 0.0), (2.6705e+06, 211.32))),
        (((Fraction(3, 1), 0), (854.0891091206599, 475.68703018220896)),
         [31298.059040782035, 244405.84856241164, 1949268.1647354485,
          15588166.694119744, 124699354.92919411, 997588860.809789],
         ((3.0, 0.0), (854.089, 475.687))),
        (((Fraction(3, 1), Fraction(1, 1)), (498.14812816021515, 788.4374477528575)),
         [101418.14144052597, 1211538.0678765494, 12918257.292110976,
          129178089.58795632, 1240105375.9704788, 11574312691.156733],
         ((3.0, 1.0), (498.148, 788.437))),
    ]

    # the modeler under test: no cross-validation, hypotheses compared by RSS
    modeler = RefiningModeler()
    modeler.use_crossvalidation = False
    modeler.compare_with_RSS = True

    for _generator, samples, (exponents, coeff) in cases:
        # build the reference function the modeler is expected to find
        if not exponents:
            # no exponents given: the reference is a constant function
            function = SingleParameterFunction()
        else:
            term = CompoundTerm.create(*exponents)
            term.coefficient = coeff[1]
            function = SingleParameterFunction(term)
        # the constant is set for both kinds of reference function
        function.constant_coefficient = coeff[0]

        measurements = [
            Measurement(Coordinate(point), None, None, sample)
            for point, sample in zip(points, samples)
        ]

        models = modeler.model([measurements])
        self.assertEqual(1, len(models))

        hypothesis = SingleParameterHypothesis(function, False)
        hypothesis.compute_cost(measurements)
        self.assertApproxFunction(function, models[0].hypothesis.function)