def test_fit_exception(self): with self.assertRaises(Exception) as cm: pmml = path.join(BASE_DIR, '../models/gb-xgboost-iris.pmml') clf = PMMLGradientBoostingClassifier(pmml) clf.fit(np.array([[]]), np.array([])) assert str(cm.exception) == 'Not supported.'
def setUp(self): df = pd.read_csv(path.join(BASE_DIR, '../models/categorical-test.csv')) cats = np.unique(df['age']) df['age'] = pd.Categorical(df['age'], categories=cats) Xte = df.iloc[:, 1:] yte = df.iloc[:, 0] self.test = (Xte, yte) pmml = path.join(BASE_DIR, '../models/gb-gbm-cat-pima.pmml') self.clf = PMMLGradientBoostingClassifier(pmml)
def test_sklearn2pmml_binary(self): # Export to PMML pipeline = PMMLPipeline([("classifier", self.ref)]) Xte, _, yte, _ = self.test pipeline.fit(Xte, yte) sklearn2pmml(pipeline, "gb-sklearn2pmml.pmml", with_repr=True) try: # Import PMML model = PMMLGradientBoostingClassifier(pmml='gb-sklearn2pmml.pmml') # Verify classification assert np.array_equal(self.ref.score(Xte, yte), model.score(Xte, yte)) assert np.allclose(self.ref.predict_proba(Xte), model.predict_proba(Xte)) finally: remove("gb-sklearn2pmml.pmml")
def test_invalid_model(self): with self.assertRaises(Exception) as cm: PMMLGradientBoostingClassifier(pmml=StringIO(""" <PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3"> <DataDictionary> <DataField name="Class" optype="categorical" dataType="string"> <Value value="setosa"/> <Value value="versicolor"/> <Value value="virginica"/> </DataField> </DataDictionary> <MiningSchema> <MiningField name="Class" usageType="target"/> </MiningSchema> </PMML> """)) assert str(cm.exception) == 'PMML model does not contain MiningModel.'
def test_non_voting_ensemble(self): with self.assertRaises(Exception) as cm: PMMLGradientBoostingClassifier(pmml=StringIO(""" <PMML xmlns="http://www.dmg.org/PMML-4_3" version="4.3"> <DataDictionary> <DataField name="Class" optype="categorical" dataType="string"> <Value value="setosa"/> <Value value="versicolor"/> <Value value="virginica"/> </DataField> </DataDictionary> <MiningModel> <MiningSchema> <MiningField name="Class" usageType="target"/> </MiningSchema> <Segmentation multipleModelMethod="mean" /> </MiningModel> </PMML> """)) assert str( cm.exception) == 'PMML model ensemble should use modelChain.'
def test_more_tags(self): clf = PMMLGradientBoostingClassifier( path.join(BASE_DIR, '../models/gb-xgboost-iris.pmml')) assert clf._more_tags() == GradientBoostingClassifier()._more_tags()
class TestCategoricalPimaGradientBoostingIntegration(TestCase): def setUp(self): df = pd.read_csv(path.join(BASE_DIR, '../models/categorical-test.csv')) cats = np.unique(df['age']) df['age'] = pd.Categorical(df['age'], categories=cats) Xte = df.iloc[:, 1:] yte = df.iloc[:, 0] self.test = (Xte, yte) pmml = path.join(BASE_DIR, '../models/gb-gbm-cat-pima.pmml') self.clf = PMMLGradientBoostingClassifier(pmml) def test_predict_proba(self): Xte, yte = self.test ref = np.array([ [0.0450486288734726, 0.9549513711265273], [0.5885160147778145, 0.4114839852221857], [0.0009551290028415, 0.9990448709971584], [0.0002422991756940, 0.9997577008243060], [0.0007366887385038, 0.9992633112614963], [0.0004013672347579, 0.9995986327652422], [0.0690449928421312, 0.9309550071578688], [0.0016227716534959, 0.9983772283465041], [0.1293482327095098, 0.8706517672904902], [0.0000087074095761, 0.9999912925904240], [0.0000791716298258, 0.9999208283701742], [0.0024433990519477, 0.9975566009480523], [0.0201993339302258, 0.9798006660697742], [0.0134746789867266, 0.9865253210132735], [0.2373888296249749, 0.7626111703750250], [0.0397951119541024, 0.9602048880458975], [0.6048657697720786, 0.3951342302279213], [0.0783995167388520, 0.9216004832611480], [0.0770031400315019, 0.9229968599684981], [0.0039225633807996, 0.9960774366192003], [0.0870982756252495, 0.9129017243747506], [0.0518153619780974, 0.9481846380219027], [0.0529373167466456, 0.9470626832533544], [0.1465303709628746, 0.8534696290371255], [0.2948688765091694, 0.7051311234908306], [0.0093530273605733, 0.9906469726394268], [0.8000454863890519, 0.1999545136109481], [0.9840595476042330, 0.0159404523957671], [0.9994934828677855, 0.0005065171322145], [0.9332373260900556, 0.0667626739099444], [0.8301575628390390, 0.1698424371609611], [0.9994014211291834, 0.0005985788708167], [0.9964913762897571, 0.0035086237102429], [0.9877457479606355, 0.0122542520393643], [0.9866554945920915, 0.0133445054079084], [0.9933345920713418, 0.0066654079286583], [0.9987899464870794, 0.0012100535129206], [0.9815251490646326, 0.0184748509353674], [0.8197681630417475, 0.1802318369582524], [0.6525706661194860, 0.3474293338805139], [0.8825708419171311, 0.1174291580828689], [0.9840595476042330, 0.0159404523957671], [0.8106506748060572, 0.1893493251939427], [0.9990627397651777, 0.0009372602348224], [0.8763278695225744, 0.1236721304774257], [0.7824982271907117, 0.2175017728092884], [0.9992088489858751, 0.0007911510141249], [0.9994453225672126, 0.0005546774327873], [0.5301562607268657, 0.4698437392731343], [0.9877457479606355, 0.0122542520393643], [0.9698651982991124, 0.0301348017008876], [0.7011874523008232, 0.2988125476991768], ]) assert np.allclose(ref, self.clf.predict_proba(Xte)) def test_score(self): Xte, yte = self.test ref = 0.9615384615384616 assert ref == self.clf.score(Xte, yte)
def test_R_xgboost(self): pmml = path.join(BASE_DIR, '../models/gb-xgboost-iris.pmml') clf = PMMLGradientBoostingClassifier(pmml) # Verify classification Xte, yte, _, _ = self.test ref = 0.9933333333333333 assert ref == clf.score(Xte, yte) ref = [ [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9757426417447441, 0.0128682851648328, 0.0113890730904230], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.9760594257941966, 0.0128724629749537, 0.0110681112308498], [0.0106166153088319, 0.9785400205188246, 0.0108433641723436], [0.0106166153088319, 0.9785400205188246, 0.0108433641723436], [0.0101155963333861, 0.9323607906747136, 0.0575236129919005], [0.0119572590868193, 0.9754718678170027, 0.0125708730961781], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0159323759619904, 0.9288613767350993, 0.0552062473029103], [0.0124389433616677, 0.9744837805406320, 0.0130772760977003], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0140073718554280, 0.9712664360693611, 0.0147261920752109], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0141667403550577, 0.9709395207122262, 0.0148937389327162], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106166153088319, 0.9785400205188246, 0.0108433641723436], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0141667403550577, 0.9709395207122262, 0.0148937389327162], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0345623929310359, 0.4340977426739484, 0.5313398643950156], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0100985567108226, 0.9307902380901225, 0.0591112051990549], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0104458453096068, 0.9628000437291715, 0.0267541109612217], [0.0367540158081575, 0.7002838055528258, 0.2629621786390166], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0498135826560765, 0.4929875804501168, 0.4571988368938067], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0159323759619904, 0.9288613767350993, 0.0552062473029103], [0.0106166153088319, 0.9785400205188246, 0.0108433641723436], [0.0119572590868193, 0.9754718678170027, 0.0125708730961781], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0140073718554280, 0.9712664360693611, 0.0147261920752109], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0106132395880432, 0.9782288782391356, 0.0111578821728212], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0164145902620253, 0.0255525167427301, 0.9580328929952446], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0435806598526578, 0.2599986718108350, 0.6964206683365073], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0142345906453859, 0.0312354373543125, 0.9545299720003015], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0184930026637306, 0.0305865778733861, 0.9509204194628834], [0.0164145902620253, 0.0255525167427301, 0.9580328929952446], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0510645228555055, 0.2463746443944004, 0.7025608327500941], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0184930026637306, 0.0305865778733861, 0.9509204194628834], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0125007199528260, 0.0186054175395307, 0.9688938625076433], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0299366378213490, 0.0965402806010371, 0.8735230815776139], [0.0120440726761590, 0.0544554457865047, 0.9335004815373361], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0464351273705071, 0.2443049639688132, 0.7092599086606799], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0363615404352432, 0.1786334851333196, 0.7850049744314372], [0.0252911348271332, 0.0660520596210707, 0.9086568055517961], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0250153077815250, 0.2450614165877344, 0.7299232756307406], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0142478270786007, 0.0303346044639386, 0.9554175684574607], [0.0164145902620253, 0.0255525167427301, 0.9580328929952446], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0108613904938508, 0.0140656636934379, 0.9750729458127113], [0.0125007199528260, 0.0186054175395307, 0.9688938625076433], [0.0108777095343255, 0.0125843153938132, 0.9765379750718612], [0.0108567094134873, 0.0144905851493065, 0.9746527054372062], [0.0162658935674777, 0.0343798537311987, 0.9493542527013236], ] assert np.allclose(ref, clf.predict_proba(Xte))