def test_SimpleTree_to_string_reg_decimals(self):
    """Check decimal formatting in ``to_string`` for a regression tree.

    A tree is grown on the ``housing`` table with ``min_instances=1``. The
    line at index 3 of its textual dump is compared verbatim with a stored
    expectation.
    """
    housing = Table("housing")
    model = SimpleTreeReg(min_instances=1)(housing)
    dumped_lines = model.to_string().split("\n")
    self.assertEqual(dumped_lines[3], ' LSTAT (19.934: 430.0)')
def test_SimpleTree_regression_tree(self):
    """Pin the serialized structure of a regression tree.

    A tree is grown on ``self.data_reg`` with ``min_instances=5``.
    ``dumps_tree`` is applied to the model's root ``node``, and the result
    must match the stored string exactly.
    """
    expected_dump = '{ 0 2 { 1 4 0.13895 { 1 4 -0.32607 { 2 4.60993 1.71141 } { 2 4.96454 3.56122 } } { 2 7.09220 -4.32343 } } { 1 4 -0.35941 { 0 0 { 1 5 -0.20027 { 2 3.54255 0.95095 } { 2 5.50000 -5.56049 } } { 2 7.62411 2.03615 } } { 1 5 0.40797 { 1 3 0.83459 { 2 3.71094 0.27028 } { 2 5.18490 3.70920 } } { 2 5.77083 5.93398 } } } }'
    model = SimpleTreeReg(min_instances=5)(self.data_reg)
    self.assertEqual(model.dumps_tree(model.node), expected_dump)
def test_SimpleTree_to_string_cls_decimals(self):
    """Check decimal formatting in ``to_string`` on the ``voting`` table.

    The tree is grown with ``min_instances=1``. The line at index 3 of the
    dump is compared verbatim with a stored expectation.

    NOTE(review): a method with this exact name is defined again right
    after this one (the ``lenses`` variant). If both live in the same
    class, the later ``def`` rebinds the name, so this test never runs.
    Consider giving it a distinct name.

    NOTE(review): the model is built with ``SimpleTreeReg``, although the
    expected text looks like a class distribution. Confirm that is
    intentional.
    """
    voting = Table("voting")
    model = SimpleTreeReg(min_instances=1)(voting)
    dumped_lines = model.to_string().split("\n")
    self.assertEqual(
        dumped_lines[3],
        ' adoption-of-the-budget-resolution ([3.7, 249.7])',
    )
def test_SimpleTree_to_string_cls_decimals(self):
    """Check decimal formatting in ``to_string`` on the bundled lenses data.

    The table is loaded from ``datasets/lenses.tab`` through
    ``test_filename`` and a tree is grown with ``min_instances=1``. The line
    at index 3 of the dump must match the stored text exactly.
    """
    lenses = Table(test_filename("datasets/lenses.tab"))
    model = SimpleTreeReg(min_instances=1)(lenses)
    dumped_lines = model.to_string().split("\n")
    self.assertEqual(dumped_lines[3], ' astigmatic ([4.0, 3.0, 5.0])')
def test_SimpleTree_to_string_regression(self):
    """Check the full ``to_string`` dump of a small hand-built regression tree.

    The domain has one discrete attribute ``d1`` (values ``e``/``f``), one
    continuous attribute ``c1``, and a continuous target ``cls``. Six rows
    are fitted with ``min_instances=1``, and the whole dump is compared
    with the expected text.

    NOTE(review): a method with this exact name, with an identical expected
    string, follows immediately. If both live in the same class, this
    definition is shadowed and never runs. Consider removing one or
    renaming it.
    """
    rows = [
        ["e", 1, 10],
        ["e", 1, 20],
        ["e", 2, 20],
        ["f", 2, 30],
        ["e", 3, 10],
        ["f", 3, 30],
    ]
    attributes = [
        DiscreteVariable(name="d1", values="ef"),
        ContinuousVariable(name="c1"),
    ]
    target = ContinuousVariable(name="cls")
    table = Table(Domain(attributes, target), rows)

    model = SimpleTreeReg(min_instances=1)(table)

    expected_dump = (
        "\n"
        "d1 (20.0: 6.0)\n"
        ": e\n"
        " c1 (15.0: 4.0)\n"
        " : <=2.5\n"
        " c1 (16.667: 3.0)\n"
        " : <=1.5 --> (15.0: 2.0)\n"
        " : >1.5 --> (20.0: 1.0)\n"
        " : >2.5 --> (10.0: 1.0)\n"
        ": f --> (30.0: 2.0)"
    )
    self.assertEqual(model.to_string(), expected_dump)
def test_SimpleTree_to_string_regression(self):
    """Compare the complete ``to_string`` output of a tiny regression tree.

    The table has a discrete feature ``d1`` (``e``/``f``), a continuous
    feature ``c1``, and a continuous class ``cls``. The tree is grown with
    ``min_instances=1``.
    """
    features = [
        DiscreteVariable(name='d1', values='ef'),
        ContinuousVariable(name='c1'),
    ]
    domain = Domain(features, ContinuousVariable(name='cls'))
    table = Table(domain, [
        ['e', 1, 10],
        ['e', 1, 20],
        ['e', 2, 20],
        ['f', 2, 30],
        ["e", 3, 10],
        ['f', 3, 30],
    ])
    learner = SimpleTreeReg(min_instances=1)
    dumped = learner(table).to_string()
    self.assertEqual(
        dumped,
        '\nd1 (20.0: 6.0)\n: e\n c1 (15.0: 4.0)\n : <=2.5\n c1 (16.667: 3.0)\n : <=1.5 --> (15.0: 2.0)\n : >1.5 --> (20.0: 1.0)\n : >2.5 --> (10.0: 1.0)\n: f --> (30.0: 2.0)',
    )
def test_SimpleTree_regression(self):
    """Predicting on the training data yields one value per row.

    The model is fitted on ``self.data_reg`` with default settings. It is
    then applied to the same table, and the prediction array's shape must
    be ``(self.N,)``.
    """
    model = SimpleTreeReg()(self.data_reg)
    predictions = model(self.data_reg)
    self.assertEqual(predictions.shape, (self.N,))