Ejemplo n.º 1
0
    def test_basic_tree(self):
        """Checks repr(), str() and pretty() of a tree with a single split.

        The root tests numerical feature "f1" against threshold 1.5. Its
        positive child holds a RegressionValue and its negative child a
        ProbabilityValue, so both leaf value kinds appear in the rendering.
        """
        f1_spec = dataspec_lib.SimpleColumnSpec(
            name="f1", type=dataspec_lib.ColumnType.NUMERICAL)
        root = node_lib.NonLeafNode(
            condition=condition_lib.NumericalHigherThanCondition(
                feature=f1_spec, threshold=1.5, missing_evaluation=False),
            pos_child=node_lib.LeafNode(value=value_lib.RegressionValue(
                value=5.0, num_examples=10, standard_deviation=1.0)),
            neg_child=node_lib.LeafNode(value=value_lib.ProbabilityValue(
                probability=[0.5, 0.4, 0.1], num_examples=10)))
        tree = tree_lib.Tree(root)

        tree_repr = repr(tree)
        logging.info("Tree repr:\n%s", tree_repr)
        # repr() must stay on one line: no line break may appear in it.
        self.assertNotIn("\n", tree_repr)

        logging.info("Tree str:\n%s", tree)

        pretty = tree.pretty()
        logging.info("Pretty:\n%s", pretty)

        # One line per node; children are indented by four spaces and every
        # line, including the last, ends with a newline.
        expected_pretty = (
            "(f1 >= 1.5; miss=False, score=None)\n"
            "    ├─(pos)─ RegressionValue(value=5.0,sd=1.0,n=10)\n"
            "    └─(neg)─ ProbabilityValue([0.5, 0.4, 0.1],n=10)\n")
        self.assertEqual(pretty, expected_pretty)
Ejemplo n.º 2
0
 def test_core_value_to_value_classifier(self):
     """Converts a core classifier node into a ProbabilityValue.

     The core distribution holds the counts [0, 8, 2] with a total of 10.
     The expected probabilities are [0.8, 0.2], so the count at index 0 is
     not part of the converted value. NOTE(review): presumably index 0 is a
     reserved slot in the core format; confirm against core_value_to_value.
     """
     core_node = decision_tree_pb2.Node()
     distribution = core_node.classifier.distribution
     distribution.counts[:] = [0.0, 8.0, 2.0]
     distribution.sum = 10.0

     converted = value_lib.core_value_to_value(core_node)
     self.assertEqual(
         converted,
         value_lib.ProbabilityValue(probability=[0.8, 0.2], num_examples=10))
Ejemplo n.º 3
0
 def test_non_leaf_with_children(self):
     """Builds a non-leaf node with two leaf children and logs it.

     Nothing is asserted; the test only checks that construction and the
     logging call complete without raising.
     """
     numerical_f1 = dataspec_lib.SimpleColumnSpec(
         name="f1", type=dataspec_lib.ColumnType.NUMERICAL)
     split = condition_lib.NumericalHigherThanCondition(
         feature=numerical_f1, threshold=1.5, missing_evaluation=False)
     positive = node_lib.LeafNode(value=value_lib.RegressionValue(
         value=5.0, num_examples=10, standard_deviation=1.0))
     negative = node_lib.LeafNode(value=value_lib.ProbabilityValue(
         probability=[0.5, 0.4, 0.1], num_examples=10))
     node = node_lib.NonLeafNode(
         condition=split, pos_child=positive, neg_child=negative)
     logging.info("node:\n%s", node)
Ejemplo n.º 4
0
 def test_basic_tree_with_label_classes(self):
     """Plots a single-split tree with class names attached and saves it.

     The tree is built with label classes "x", "y" and "z", then rendered by
     model_plotter.plot_tree and handed to the test's _save_plot helper.
     """
     root = node_lib.NonLeafNode(
         condition=condition_lib.NumericalHigherThanCondition(
             feature=dataspec_lib.SimpleColumnSpec(
                 name="f1", type=dataspec_lib.ColumnType.NUMERICAL),
             threshold=1.5,
             missing_evaluation=False),
         pos_child=node_lib.LeafNode(value=value_lib.RegressionValue(
             value=5.0, num_examples=10, standard_deviation=1.0)),
         neg_child=node_lib.LeafNode(value=value_lib.ProbabilityValue(
             probability=[0.5, 0.4, 0.1], num_examples=10)))
     labelled_tree = tree_lib.Tree(root, label_classes=["x", "y", "z"])
     self._save_plot(model_plotter.plot_tree(tree=labelled_tree))
Ejemplo n.º 5
0
 def test_probability(self):
     """Logs a three-entry ProbabilityValue built from 10 examples."""
     logging.info(
         "value:\n%s",
         value_lib.ProbabilityValue(probability=[0.5, 0.4, 0.1],
                                    num_examples=10))
Ejemplo n.º 6
0
 def test_stump_with_label_classes(self):
     """Logs a tree made of a single leaf, with class names "a", "b", "c"."""
     leaf = node_lib.LeafNode(value=value_lib.ProbabilityValue(
         probability=[0.5, 0.4, 0.1], num_examples=10))
     stump = tree_lib.Tree(leaf, label_classes=["a", "b", "c"])
     logging.info("Tree:\n%s", stump)