예제 #1
0
 def test_null_split_returns_all_values(self):
     """A Split with no conditions added exposes the same column 0 as its input."""
     features = pd.DataFrame({"a": [1, 2]}).values
     target = np.array([1, 2])
     data = make_bartpy_data(features, target)

     # No SplitCondition is added, so the test expects nothing to be dropped.
     unconditioned = Split(data)
     conditioned_data = unconditioned.data

     expected_column = list(data.X.get_column(0))
     actual_column = list(conditioned_data.X.get_column(0))
     self.assertListEqual(expected_column, actual_column)
예제 #2
0
 def setUp(self):
     """Create random single-feature data and a five-node tree over it.

     Leaves ``d`` and ``e`` are handed to decision node ``c``; leaf ``b``
     and node ``c`` are handed to decision node ``a``. Every node receives
     its own condition-free ``Split`` built from the same data.
     """
     n_obs = 1000
     features = pd.DataFrame({"a": np.random.normal(size=n_obs)})
     target = np.array(np.random.normal(size=n_obs))
     self.data = make_bartpy_data(features, target)

     def fresh_split():
         # Each node gets a separate Split instance over the shared data.
         return Split(self.data)

     # Bottom-up: a node must exist before the decision node that takes it.
     self.d = LeafNode(fresh_split())
     self.e = LeafNode(fresh_split())
     self.c = DecisionNode(fresh_split(), self.d, self.e)
     self.b = LeafNode(fresh_split())
     self.a = DecisionNode(fresh_split(), self.b, self.c)

     all_nodes = [self.a, self.b, self.c, self.d, self.e]
     self.tree = Tree(all_nodes)
예제 #3
0
 def setUp(self):
     """Create a two-row, un-normalized data set and a five-node tree over it.

     Decision node ``c`` is built from leaves ``d`` and ``e``; decision
     node ``a`` is built from leaf ``b`` and node ``c``.
     """
     features = pd.DataFrame({"a": [1, 2]})
     target = np.array([1, 2])
     data = make_bartpy_data(features, target, normalize=False)
     self.data = data

     # Assemble the nodes as locals first, in dependency order.
     d = LeafNode(Split(data))
     e = LeafNode(Split(data))
     c = DecisionNode(Split(data), d, e)
     b = LeafNode(Split(data))
     a = DecisionNode(Split(data), b, c)

     self.a, self.b, self.c, self.d, self.e = a, b, c, d, e
     self.tree = Tree([a, b, c, d, e])
예제 #4
0
 def test_single_condition_data(self):
     """One condition on column 0 at value 1 partitions ``[1, 2]``.

     The test expects the ``le`` condition to keep ``[1]`` and the ``gt``
     condition to keep ``[2]``.
     """
     features = pd.DataFrame({"a": [1, 2]}).values
     data = make_bartpy_data(features, np.array([1, 2]))

     left_condition = SplitCondition(0, 1, le)
     right_condition = SplitCondition(0, 1, gt)

     # Each side starts from its own fresh, condition-free Split.
     left_split = Split(data) + left_condition
     right_split = Split(data) + right_condition

     left_values = list(left_split.data.X.get_column(0))
     self.assertListEqual([1], left_values)
     right_values = list(right_split.data.X.get_column(0))
     self.assertListEqual([2], right_values)
예제 #5
0
    def test_combined_condition_data(self):
        """Two conditions added to a Split are applied together.

        Column 0 holds ``[1, 2, 3, 4]``. After adding ``SplitCondition(0, 3, le)``
        and then ``SplitCondition(0, 1, gt)``, the test expects ``[2, 3]`` to
        remain.
        """
        features = pd.DataFrame({"a": [1, 2, 3, 4]}).values
        data = make_bartpy_data(features, np.array([1, 2, 1, 1]))

        # Only the two conditions actually applied are constructed; the
        # complementary conditions were never used by this test.
        first_left_condition = SplitCondition(0, 3, le)
        second_right_condition = SplitCondition(0, 1, gt)

        split = Split(data)
        updated_split = split + first_left_condition + second_right_condition
        remaining = list(updated_split.data.X.get_column(0))
        self.assertListEqual([2, 3], remaining)
예제 #6
0
    def test_most_recent_split(self):
        """``most_recent_split_condition`` reports the last condition added.

        Checked after one condition has been added to a Split, and again
        after a second condition has been added on top of it.
        """
        features = pd.DataFrame({"a": [1, 2, 3, 4]}).values
        data = make_bartpy_data(features, np.array([1, 2, 1, 1]))

        # Only the two conditions actually applied are constructed; the
        # complementary conditions were never used by this test.
        first_left_condition = SplitCondition(0, 3, le)
        second_right_condition = SplitCondition(0, 1, gt)

        split = Split(data)
        updated_split = split + first_left_condition + second_right_condition

        self.assertEqual(
            (split + first_left_condition).most_recent_split_condition(),
            first_left_condition)
        self.assertEqual(updated_split.most_recent_split_condition(),
                         second_right_condition)
예제 #7
0
 def setUp(self):
     """Create 20 random observations where ``y`` is ``X`` plus small noise.

     The data object is built from a single-column frame named ``"a"``
     with ``normalize=False``.
     """
     n_obs = 20
     self.X = np.random.normal(size=n_obs)
     # The noise is drawn after X, with standard deviation 0.1.
     noise = np.random.normal(scale=0.1, size=n_obs)
     self.y = self.X + noise

     features = pd.DataFrame({"a": self.X})
     self.data = make_bartpy_data(features, self.y, normalize=False)