コード例 #1
0
ファイル: test_split.py プロジェクト: vishalbelsare/bartpy
    def test_multiple_variables(self):
        """Check the mask of a CombinedCondition built over two variables."""
        # Lay the first fixture column out as a 3x2 matrix so that both
        # variable 0 and variable 1 exist for the conditions below.
        two_column_X = self.X[:, 0].reshape(3, 2)
        per_variable_rules = [
            SplitCondition(0, 2, gt),
            SplitCondition(1, 1, gt),
        ]
        combined_condition = CombinedCondition([0, 1], per_variable_rules)

        row_mask = list(combined_condition.condition(two_column_X))
        self.assertListEqual(row_mask, [False, True, True])
コード例 #2
0
ファイル: test_split.py プロジェクト: vishalbelsare/bartpy
    def test_multiple_conditions(self):
        """Check bounds and mask when two conditions target the same variable."""
        lower_rule = SplitCondition(0, 2, gt)
        upper_rule = SplitCondition(0, 5, le)
        combined_condition = CombinedCondition([0], [lower_rule, upper_rule])

        # The two rules should show up as the min/max recorded for variable 0.
        bounds = combined_condition.variables[0]
        self.assertEqual(bounds.min_value, 2)
        self.assertEqual(bounds.max_value, 5)

        row_mask = list(combined_condition.condition(self.X))
        self.assertListEqual(row_mask,
                             [False, False, True, False, True, True])
コード例 #3
0
ファイル: test_split.py プロジェクト: stjordanis/bartpy
 def test_single_condition_data(self):
     """Check which rows each half of a single split keeps."""
     frame = pd.DataFrame({"a": [1, 2]})
     data = Data(frame.values, np.array([1, 2]))

     at_most_one = SplitCondition(0, 1, le)
     above_one = SplitCondition(0, 1, gt)

     # Each side starts from its own fresh Split over the same data.
     left_split = Split(data) + at_most_one
     right_split = Split(data) + above_one

     self.assertListEqual([1], list(left_split.data.X[:, 0]))
     self.assertListEqual([2], list(right_split.data.X[:, 0]))
コード例 #4
0
ファイル: test_split.py プロジェクト: vishalbelsare/bartpy
    def test_combined_condition_data(self):
        """Check the rows left after chaining two conditions onto a Split."""
        covariates = pd.DataFrame({"a": [1, 2, 3, 4]}).values
        data = make_bartpy_data(covariates, np.array([1, 2, 1, 1]))

        # Each pair is (le, gt) around the same threshold on variable 0.
        pair_at_3 = (SplitCondition(0, 3, le), SplitCondition(0, 3, gt))
        pair_at_1 = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        keep_le_3 = pair_at_3[0]
        keep_gt_1 = pair_at_1[1]

        updated_split = Split(data) + keep_le_3 + keep_gt_1
        remaining_values = list(updated_split.data.X.get_column(0))
        self.assertListEqual([2, 3], remaining_values)
コード例 #5
0
def sample_split_condition(node: LeafNode) -> Optional[Tuple[SplitCondition, SplitCondition]]:
    """
    Randomly sample a splitting rule for a particular leaf node.

    Works based on two random draws:

      - draw a variable to split on from the node's splittable variables
        (``np.random.choice`` with no weights)
      - draw a value of that variable to split on

    Parameters
    ----------
    node: LeafNode
        The leaf whose data the splitting rule is sampled from

    Returns
    -------
    Optional[Tuple[SplitCondition, SplitCondition]]
        A pair of conditions on the same variable and value, the first built
        with ``le`` and the second with ``gt``.
        None if there isn't a possible non-degenerate split, i.e. there is no
        splittable variable or no splittable value for the drawn variable.
    """
    candidate_variables = list(node.split.data.splittable_variables())
    if not candidate_variables:
        # np.random.choice raises ValueError on an empty sequence; honour the
        # documented "return None" contract instead of crashing the caller.
        return None

    split_variable = np.random.choice(candidate_variables)
    split_value = node.data.random_splittable_value(split_variable)
    if split_value is None:
        return None
    return SplitCondition(split_variable, split_value, le), SplitCondition(split_variable, split_value, gt)
コード例 #6
0
ファイル: test_split.py プロジェクト: vishalbelsare/bartpy
    def test_most_recent_split(self):
        """Check that a Split reports the last condition added to it."""
        covariates = pd.DataFrame({"a": [1, 2, 3, 4]}).values
        data = make_bartpy_data(covariates, np.array([1, 2, 1, 1]))

        # Each pair is (le, gt) around the same threshold on variable 0.
        pair_at_3 = (SplitCondition(0, 3, le), SplitCondition(0, 3, gt))
        pair_at_1 = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        keep_le_3 = pair_at_3[0]
        keep_gt_1 = pair_at_1[1]

        root_split = Split(data)
        two_conditions_deep = root_split + keep_le_3 + keep_gt_1
        one_condition_deep = root_split + keep_le_3

        self.assertEqual(one_condition_deep.most_recent_split_condition(),
                         keep_le_3)
        self.assertEqual(two_conditions_deep.most_recent_split_condition(),
                         keep_gt_1)
コード例 #7
0
    def setUp(self):
        """Build the tree fixture shared by the tests.

        Splits a root leaf on variable 0 (``self.a`` with children ``self.b``
        and ``self.x``), wraps the three nodes in ``self.tree``, then applies
        a "grow" TreeMutation from ``self.x`` to ``self.c``, a split on
        variable 1 whose children are ``self.d`` and ``self.e``.
        """
        covariates = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]}).values
        self.data = Data(covariates, np.array([1, 2, 3]))

        root_leaf = LeafNode(Split(self.data))
        root_rules = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        self.a = split_node(root_leaf, root_rules)
        self.b, self.x = self.a.left_child, self.a.right_child
        self.tree = Tree([self.a, self.b, self.x])

        second_rules = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
        self.c = split_node(self.a._right_child, second_rules)
        grow_x_into_c = TreeMutation("grow", self.x, self.c)
        mutate(self.tree, grow_x_into_c)

        self.d, self.e = self.c.left_child, self.c.right_child
コード例 #8
0
    def setUp(self):
        """Build the tree fixture shared by the tests.

        Splits a root leaf on variable 0 (``self.a`` with children ``self.b``
        and ``self.x``), wraps the three nodes in ``self.tree``, then applies
        a "grow" TreeMutation from ``self.x`` to ``self.c``, a split on
        variable 1 whose children are ``self.d`` and ``self.e``.
        """
        frame = pd.DataFrame({"a": [1, 2, 3], "b": [1, 2, 3]})
        X = format_covariate_matrix(frame)
        targets = np.array([1, 2, 3]).astype(float)
        self.data = Data(X, targets)

        root_leaf = LeafNode(Split(self.data))
        root_rules = (SplitCondition(0, 1, le), SplitCondition(0, 1, gt))
        self.a = split_node(root_leaf, root_rules)
        self.b, self.x = self.a.left_child, self.a.right_child
        self.tree = Tree([self.a, self.b, self.x])

        second_rules = (SplitCondition(1, 2, le), SplitCondition(1, 2, gt))
        self.c = split_node(self.a._right_child, second_rules)
        grow_x_into_c = TreeMutation("grow", self.x, self.c)
        mutate(self.tree, grow_x_into_c)

        self.d, self.e = self.c.left_child, self.c.right_child
コード例 #9
0
ファイル: test_split.py プロジェクト: vishalbelsare/bartpy
 def test_single_condition(self):
     """Check the mask of a CombinedCondition wrapping one condition."""
     above_three = SplitCondition(0, 3, gt)
     combined_condition = CombinedCondition([0], [above_three])

     expected_mask = [False, False, True, True, False, True]
     row_mask = list(combined_condition.condition(self.X))
     self.assertListEqual(row_mask, expected_mask)