def test_tree_gv_format(self):
    """Check the DOT (``digraph``) text produced by ``PmmlTree._repr_gv_``.

    A ``PmmlTree`` is built from the ``TREE_XML`` fixture. Its
    ``_repr_gv_()`` output, stripped of surrounding whitespace, must equal
    the expected literal below (also stripped).

    NOTE(review): the expected literal is reproduced verbatim, including its
    internal whitespace and the single embedded line break; confirm it
    still matches the renderer's line layout.
    """
    expected = u""" digraph { root [shape=record,label=< ROOT >]; struct1 [shape=record,label=< <FONT POINT-SIZE="11">`a03` ≤ 0.4404</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:8, 1:9</FONT> >]; struct2 [shape=record,label=< <FONT POINT-SIZE="11">`a14` ≤ 0.261825</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:3, 1:9</FONT> >]; struct3 [shape=record,style=filled,fillcolor=azure2,label=< <FONT POINT-SIZE="16">0</FONT><br /><FONT POINT-SIZE="11">`a26` ≤ -0.327175</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:2</FONT> >]; struct2 -> struct3; struct4 [shape=record,style=filled,fillcolor=azure2,label=< <FONT POINT-SIZE="16">1</FONT><br /><FONT POINT-SIZE="11">`a26` > -0.327175</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:1, 1:9</FONT> >]; struct2 -> struct4; struct1 -> struct2; struct5 [shape=record,style=filled,fillcolor=azure2,label=< <FONT POINT-SIZE="16">0</FONT><br /><FONT POINT-SIZE="11">`a14` > 0.261825</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:5</FONT> >]; struct1 -> struct5; root -> struct1; struct6 [shape=record,label=< <FONT POINT-SIZE="11">`a03` > 0.4404</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:5, 1:105</FONT> >]; struct7 [shape=record,style=filled,fillcolor=azure2,label=< <FONT POINT-SIZE="16">1</FONT><br /><FONT POINT-SIZE="11">`a24` ≤ 0.233095</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:1, 1:89</FONT> >]; struct6 -> struct7; struct8 [shape=record,label=< <FONT POINT-SIZE="11">`a24` > 0.233095</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:4, 1:16</FONT> >]; struct9 [shape=record,label=< <FONT POINT-SIZE="11">`a16` ≤ 0.424775</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:4, 1:1</FONT> >]; struct10 [shape=record,style=filled,fillcolor=azure2,label=< <FONT POINT-SIZE="16">0</FONT><br /><FONT POINT-SIZE="11">`a20` ≤ -0.059785</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:1, 1:1</FONT> >]; struct9 -> struct10; struct11 [shape=record,style=filled,fillcolor=azure2,label=< <FONT POINT-SIZE="16">0</FONT><br /><FONT POINT-SIZE="11">`a20` > 
-0.059785</FONT><br /><FONT POINT-SIZE="11">LABELS: 0:3</FONT> >]; struct9 -> struct11; struct8 -> struct9; struct12 [shape=record,style=filled,fillcolor=azure2,label=< <FONT POINT-SIZE="16">1</FONT><br /><FONT POINT-SIZE="11">`a16` > 0.424775</FONT><br /><FONT POINT-SIZE="11">LABELS: 1:15</FONT> >]; struct8 -> struct12; struct6 -> struct8; root -> struct6; } """.strip()

    # Build the model from the raw PMML XML fixture.
    tree_element = ET.fromstring(TREE_XML)
    tree_model = PmmlTree(tree_element)

    # Compare ignoring leading/trailing whitespace only.
    actual = tree_model._repr_gv_().strip()
    assert actual == expected
def test_regression_text_format(self):
    """Check the text produced by ``repr()`` of a ``PmmlRegression``.

    The model is built from the ``REGRESSION_XML`` fixture. Its stripped
    ``repr`` must equal the stripped expected literal.
    """
    expected = """ Function: regression Target Field: y Normalization: softmax Target: no y = 125.56601826 - 28.6617384 * x1 - 20.42027426 * x2 """.strip()

    model = PmmlRegression(ET.fromstring(REGRESSION_XML))
    rendered = repr(model).strip()
    assert rendered == expected
def test_tree_text_format(self):
    """Check the text tree produced by ``repr()`` of a ``PmmlTree``.

    The model is built from the ``TREE_XML`` fixture. Its stripped ``repr``
    must equal the stripped expected literal.
    """
    # Imported locally, as in the original test body.
    from odps.ml.models.pmml import PmmlTree
    from odps.compat import ElementTree as ET

    expected = """ ROOT ├── WHEN `a03` ≤ 0.4404 (COUNTS: 0:8, 1:9) │ ├── WHEN `a14` ≤ 0.261825 (COUNTS: 0:3, 1:9) │ │ ├── SCORE = 0 WHEN `a26` ≤ -0.327175 (COUNTS: 0:2) │ │ └── SCORE = 1 WHEN `a26` > -0.327175 (COUNTS: 0:1, 1:9) │ └── SCORE = 0 WHEN `a14` > 0.261825 (COUNTS: 0:5) └── WHEN `a03` > 0.4404 (COUNTS: 0:5, 1:105) ├── SCORE = 1 WHEN `a24` ≤ 0.233095 (COUNTS: 0:1, 1:89) └── WHEN `a24` > 0.233095 (COUNTS: 0:4, 1:16) ├── WHEN `a16` ≤ 0.424775 (COUNTS: 0:4, 1:1) │ ├── SCORE = 0 WHEN `a20` ≤ -0.059785 (COUNTS: 0:1, 1:1) │ └── SCORE = 0 WHEN `a20` > -0.059785 (COUNTS: 0:3) └── SCORE = 1 WHEN `a16` > 0.424775 (COUNTS: 1:15) """.strip()

    tree_model = PmmlTree(ET.fromstring(TREE_XML))
    rendered = repr(tree_model).strip()
    assert rendered == expected
def testArrayParse(self):
    """Camel-case alias of ``test_array_parse``.

    This method's body was a verbatim copy of ``test_array_parse``. It now
    delegates to that method so the two cannot silently diverge.
    NOTE(review): this looks like a leftover from an older naming
    convention; consider deleting it once nothing refers to this name.
    """
    self.test_array_parse()
def test_array_parse(self):
    """Check ``parse_pmml_array`` on a string-typed and an int-typed <Array>.

    The string case covers bare tokens, double-quoted tokens containing a
    space, and escaped inner quotes with a trailing space. The int case
    checks that tokens are converted to integers.
    """
    # (PMML <Array> XML, expected parsed list), checked in order.
    cases = [
        (r'<Array n="3" type="string">ab "a b" "with \"quotes\" "</Array>',
         ['ab', 'a b', 'with "quotes" ']),
        (r'<Array n="3" type="int">1 22 3</Array>',
         [1, 22, 3]),
    ]
    for xml_text, wanted in cases:
        array_element = ET.fromstring(xml_text)
        assert parse_pmml_array(array_element) == wanted