def test_basic_tree(self):
    """Checks the repr, str and pretty renderings of a single-split tree.

    The root is a numerical higher-than condition on feature "f1" with a
    threshold of 1.5. Its positive child is a regression leaf and its
    negative child is a probability leaf.
    """
    split_feature = dataspec_lib.SimpleColumnSpec(
        name="f1", type=dataspec_lib.ColumnType.NUMERICAL
    )
    split = condition_lib.NumericalHigherThanCondition(
        feature=split_feature, threshold=1.5, missing_evaluation=False
    )
    positive_leaf = node_lib.LeafNode(
        value=value_lib.RegressionValue(
            value=5.0, num_examples=10, standard_deviation=1.0
        )
    )
    negative_leaf = node_lib.LeafNode(
        value=value_lib.ProbabilityValue(
            probability=[0.5, 0.4, 0.1], num_examples=10
        )
    )
    root = node_lib.NonLeafNode(
        condition=split, pos_child=positive_leaf, neg_child=negative_leaf
    )
    tree = tree_lib.Tree(root)

    # repr() has to fit on one line, i.e. contain no newline character.
    single_line = repr(tree)
    logging.info("Tree repr:\n%s", single_line)
    self.assertNotIn("\n", single_line)

    # The str() form is only logged, not asserted on.
    logging.info("Tree str:\n%s", tree)

    rendered = tree.pretty()
    logging.info("Pretty:\n%s", rendered)
    # NOTE(review): the expected literal below is kept character-for-character
    # as it appears in the reviewed source; confirm its whitespace against the
    # actual pretty() output.
    self.assertEqual(
        rendered,
        """(f1 >= 1.5; miss=False, score=None) ├─(pos)─ RegressionValue(value=5.0,sd=1.0,n=10) └─(neg)─ ProbabilityValue([0.5, 0.4, 0.1],n=10) """,
    )
def test_core_value_to_value_classifier(self):
    """Checks the conversion of a classifier proto node to a ProbabilityValue.

    The proto node is filled with the counts [0.0, 8.0, 2.0] and a sum of
    10.0. The conversion is expected to return the probabilities [0.8, 0.2]
    with 10 examples.
    """
    proto_node = decision_tree_pb2.Node()
    proto_node.classifier.distribution.counts[:] = [0.0, 8.0, 2.0]
    proto_node.classifier.distribution.sum = 10.0

    converted = value_lib.core_value_to_value(proto_node)
    expected = value_lib.ProbabilityValue(
        probability=[0.8, 0.2], num_examples=10
    )
    self.assertEqual(converted, expected)
def test_non_leaf_with_children(self):
    """Builds a non-leaf node with two leaf children and logs it.

    Nothing is asserted: the test only exercises construction and the
    string formatting of the node through the logging call.
    """
    split_feature = dataspec_lib.SimpleColumnSpec(
        name="f1", type=dataspec_lib.ColumnType.NUMERICAL
    )
    split = condition_lib.NumericalHigherThanCondition(
        feature=split_feature, threshold=1.5, missing_evaluation=False
    )
    positive_leaf = node_lib.LeafNode(
        value=value_lib.RegressionValue(
            value=5.0, num_examples=10, standard_deviation=1.0
        )
    )
    negative_leaf = node_lib.LeafNode(
        value=value_lib.ProbabilityValue(
            probability=[0.5, 0.4, 0.1], num_examples=10
        )
    )
    node = node_lib.NonLeafNode(
        condition=split, pos_child=positive_leaf, neg_child=negative_leaf
    )
    logging.info("node:\n%s", node)
def test_basic_tree_with_label_classes(self):
    """Plots a single-split tree that carries the label classes x, y and z.

    The resulting plot is handed to the test case's `_save_plot` helper;
    no assertion is made on its content.
    """
    split_feature = dataspec_lib.SimpleColumnSpec(
        name="f1", type=dataspec_lib.ColumnType.NUMERICAL
    )
    split = condition_lib.NumericalHigherThanCondition(
        feature=split_feature, threshold=1.5, missing_evaluation=False
    )
    positive_leaf = node_lib.LeafNode(
        value=value_lib.RegressionValue(
            value=5.0, num_examples=10, standard_deviation=1.0
        )
    )
    negative_leaf = node_lib.LeafNode(
        value=value_lib.ProbabilityValue(
            probability=[0.5, 0.4, 0.1], num_examples=10
        )
    )
    root = node_lib.NonLeafNode(
        condition=split, pos_child=positive_leaf, neg_child=negative_leaf
    )
    tree = tree_lib.Tree(root, label_classes=["x", "y", "z"])

    self._save_plot(model_plotter.plot_tree(tree=tree))
def test_probability(self):
    """Builds a three-class ProbabilityValue and logs its string form.

    Nothing is asserted; the test only exercises construction and
    formatting.
    """
    class_probabilities = [0.5, 0.4, 0.1]
    value = value_lib.ProbabilityValue(
        probability=class_probabilities, num_examples=10
    )
    logging.info("value:\n%s", value)
def test_stump_with_label_classes(self):
    """Builds a tree made of a single probability leaf and logs it.

    The tree is given the label classes a, b and c. Nothing is asserted;
    the test only exercises construction and string formatting.
    """
    only_leaf = node_lib.LeafNode(
        value=value_lib.ProbabilityValue(
            probability=[0.5, 0.4, 0.1], num_examples=10
        )
    )
    stump = tree_lib.Tree(only_leaf, label_classes=["a", "b", "c"])
    logging.info("Tree:\n%s", stump)