예제 #1
0
def test_best_split_with_combination_combining_if_too_small():
    """
    Best split where one category has to be merged with a neighbour.

    In column 0, category 2 appears in only three rows, fewer than
    ``min_child_node_size=5``. The expected split keeps category 1 on its
    own and combines categories 2 and 3 into a single child.
    """
    dependent = np.array(([1] * 5) + ([2] * 10))
    dependent_snapshot = dependent.copy()
    flat_rows = ([1, 2, 3] * 5) + ([2, 2, 3] * 3) + ([3, 2, 3] * 5) + [1, 2, 3] * 2
    independent = np.array(flat_rows).reshape(15, 3)
    independent_snapshot = independent.copy()
    tree = CHAID.Tree.from_numpy(
        independent,
        dependent,
        min_child_node_size=5,
        alpha_merge=0.055,
    )

    best = tree.generate_best_split(tree.vectorised_array, tree.observed)

    # The arrays handed to CHAID must come back untouched.
    unchanged_msg = 'Calling chaid should have no side affects for original numpy arrays'
    assert list_ordered_equal(independent, independent_snapshot), unchanged_msg
    assert list_ordered_equal(dependent, dependent_snapshot), unchanged_msg

    assert best.column_id == 0, 'Identifies correct column to split on'
    assert list_unordered_equal(best.split_map, [[1], [2, 3]]), 'Correctly identifies categories'
    assert list_unordered_equal(best.surrogates, []), 'No surrogates should be generated'
    assert best.p < 0.055
    def test_possible_groups(self):
        """
        Ensure possible groups are only adjacent numbers, before any grouping.

        No ``group()`` call has been made, so every category is expected to
        be reported as its own group. The candidate merges must match the
        pairs listed in ``possible_groupings``.
        """
        groupings = list(self.col.possible_groupings())
        possible_groupings = [(1, 2), (2, 3), (3, 4), (4, 5)]
        # The message previously claimed groups had been identified; none have yet.
        assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, before any groups are identified, possible grouping are incorrectly identified.'

        groups = list(self.col.groups())
        actual_groups = [[1], [2], [3], [4], [5], [10]]
        assert list_unordered_equal(actual_groups, groups), 'Without NaNs, before any groups are identified, actual groupings are incorrectly reported'
    def test_groups_after_grouping(self):
        """
        Ensure possible groups are only adjacent numbers after some grouping.

        Categories 4 and 2 are merged into 3, so the groups should report
        ``[2, 3, 4]`` as one group. The candidate merges must match the pairs
        listed in ``possible_groupings``.
        """
        self.col.group(3, 4)
        self.col.group(3, 2)

        groupings = list(self.col.possible_groupings())
        possible_groupings = [(1, 3), (3, 5)]
        assert list_unordered_equal(possible_groupings, groupings), 'Without NaNs, with groups are identified, possible grouping are incorrectly identified.'

        groups = list(self.col.groups())
        actual_groups = [[1], [2, 3, 4], [5], [10]]
        # Groups have been identified by this point, so the message says so.
        assert list_unordered_equal(actual_groups, groups), 'Without NaNs, with groups identified, actual groupings are incorrectly reported'
    def test_possible_groups(self):
        """
        Ensure possible groups are only adjacent numbers when NaNs are present.

        No grouping has been performed yet. Candidate merges and groups are
        translated from internal codes back to their labels via
        ``self.col.metadata``, then compared with the expected lists, which
        include the ``'<missing>'`` category.
        """
        metadata = self.col.metadata
        groupings = [(metadata[x], metadata[y]) for x, y in self.col.possible_groupings()]
        possible_groupings = [
            (1.0, 2.0), (2.0, 3.0), (3.0, 4.0), (4.0, 5.0), (1.0, '<missing>'), (2.0, '<missing>'), (3.0, '<missing>'),
            (4.0, '<missing>'), (5.0, '<missing>'), (10.0, '<missing>')
        ]
        assert list_unordered_equal(possible_groupings, groupings), 'With NaNs, before any groups are identified, possible grouping are incorrectly calculated.'

        # Translate each group's codes to labels (a dead `list(groups())` store was removed).
        groups = [[metadata[i] for i in group] for group in self.col.groups()]
        actual_groups = [[1.0], [2.0], [3.0], [4.0], [5.0], ['<missing>'], [10.0]]
        assert list_unordered_equal(actual_groups, groups), 'With NaNs, before any groups are identified, actual groupings are incorrectly reported'
    def test_groups_after_grouping(self):
        """
        Check candidate merges and groups, with NaNs, after merging 2, 3 and 4.

        Labels are looked up through ``metadata``. The ``'<missing>'``
        category is expected to remain a separate group.
        """
        col = self.col
        col.group(3.0, 4.0)
        col.group(3.0, 2.0)

        expected_pairs = [
            (1.0, 3.0), (3.0, 5.0), (1.0, '<missing>'), (3.0, '<missing>'), (5.0, '<missing>'), (10.0, '<missing>')
        ]
        found_pairs = [(col.metadata[a], col.metadata[b]) for a, b in col.possible_groupings()]
        assert list_unordered_equal(expected_pairs, found_pairs), 'With NaNs, with groups are identified, possible grouping incorrectly identified.'

        expected_groups = [[1.0], [2.0, 3.0, 4.0], [5.0], [10.0], ['<missing>']]
        found_groups = [[col.metadata[code] for code in members] for members in col.groups()]
        assert list_unordered_equal(expected_groups, found_groups), 'With NaNs, with groups identified, actual groupings are incorrectly reported'
    def test_groups_after_copy_with_nan(self):
        """
        A grouping that absorbs the NaN category must be visible on a deep copy.

        Categories 4, ``_nan`` and 2 are merged into 3 on ``self.col``. The
        assertions are then made against the result of ``deep_copy()``.
        """
        source = self.col
        source.group(3.0, 4.0)
        source.group(3.0, source._nan)
        source.group(3.0, 2.0)
        col = source.deep_copy()

        expected_pairs = [(1.0, 3.0), (3.0, 5.0)]
        found_pairs = [(col.metadata[a], col.metadata[b]) for a, b in col.possible_groupings()]
        assert list_unordered_equal(expected_pairs, found_pairs), 'With NaNs, with groups containing nan identified, possible grouping incorrectly identified.'

        expected_groups = [[1.0], [2.0, 3.0, 4.0, '<missing>'], [5.0], [10.0]]
        found_groups = [[col.metadata[code] for code in members] for members in col.groups()]
        assert list_unordered_equal(expected_groups, found_groups), 'With NaNs, with groups containing nan identified, actual groupings are incorrectly reported'
예제 #7
0
 def test_single_path(self):
     """Rules for ``self.last_node`` should be a single entry describing node 3."""
     path_rules = [
         {'data': [3], 'variable': 'c'},
         {'data': [2], 'variable': 'a'},
     ]
     expected_rules = [{'node': 3, 'rules': path_rules}]
     found = self.tree.classification_rules(self.last_node)
     assert list_unordered_equal(found, expected_rules), "Couldn't find path to determine rules from specified node"
예제 #8
0
def test_best_split_unique_values():
    """
    Best split on cleanly separable data, where no category merge is needed.

    Column 0 equals the dependent value on every row: 1 for the first five
    rows and 2 for the last five. The expected split is therefore on column
    0, with one child per category.
    """
    dependent = np.array(([1] * 5) + ([2] * 5))
    dependent_snapshot = dependent.copy()
    flat_rows = ([1, 2, 3] * 5) + ([2, 2, 3] * 5)
    independent = np.array(flat_rows).reshape(10, 3)
    independent_snapshot = independent.copy()
    tree = CHAID.Tree.from_numpy(independent, dependent, min_child_node_size=0)

    best = tree.generate_best_split(tree.vectorised_array, tree.observed)

    # The arrays handed to CHAID must come back untouched.
    unchanged_msg = 'Calling chaid should have no side affects for original numpy arrays'
    assert list_ordered_equal(independent, independent_snapshot), unchanged_msg
    assert list_ordered_equal(dependent, dependent_snapshot), unchanged_msg

    assert best.column_id == 0, 'Identifies correct column to split on'
    assert list_unordered_equal(best.split_map, [[1], [2]]), 'Correctly identifies catagories'
    assert list_unordered_equal(best.surrogates, []), 'No surrogates should be generated'
    assert best.p < 0.015
예제 #9
0
def test_best_split_with_combination_combining_if_too_small():
    """
    Best split where one category has to be merged with a neighbour.

    In column 0, category 2 appears in only three rows, fewer than
    ``min_child_node_size=5``. The expected split keeps category 1 on its
    own and combines categories 2 and 3 into a single child.
    """
    dependent = np.array(([1] * 5) + ([2] * 10))
    dependent_snapshot = dependent.copy()
    flat_rows = ([1, 2, 3] * 5) + ([2, 2, 3] * 3) + ([3, 2, 3] * 5) + [1, 2, 3] * 2
    independent = np.array(flat_rows).reshape(15, 3)
    independent_snapshot = independent.copy()
    tree = CHAID.Tree.from_numpy(
        independent,
        dependent,
        min_child_node_size=5,
        alpha_merge=0.055,
    )

    best = tree.generate_best_split(tree.vectorised_array, tree.observed)

    # The arrays handed to CHAID must come back untouched.
    unchanged_msg = 'Calling chaid should have no side affects for original numpy arrays'
    assert list_ordered_equal(independent, independent_snapshot), unchanged_msg
    assert list_ordered_equal(dependent, dependent_snapshot), unchanged_msg

    assert best.column_id == 0, 'Identifies correct column to split on'
    assert list_unordered_equal(best.split_map, [[1], [2, 3]]), 'Correctly identifies categories'
    assert list_unordered_equal(best.surrogates, []), 'No surrogates should be generated'
    assert best.p < 0.055
예제 #10
0
def test_best_split_unique_values():
    """
    Best split on cleanly separable data, where no category merge is needed.

    Column 0 equals the dependent value on every row: 1 for the first five
    rows and 2 for the last five. The expected split is therefore on column
    0, with one child per category.
    """
    dependent = np.array(([1] * 5) + ([2] * 5))
    dependent_snapshot = dependent.copy()
    flat_rows = ([1, 2, 3] * 5) + ([2, 2, 3] * 5)
    independent = np.array(flat_rows).reshape(10, 3)
    independent_snapshot = independent.copy()
    tree = CHAID.Tree.from_numpy(independent, dependent, min_child_node_size=0)

    best = tree.generate_best_split(tree.vectorised_array, tree.observed)

    # The arrays handed to CHAID must come back untouched.
    unchanged_msg = 'Calling chaid should have no side affects for original numpy arrays'
    assert list_ordered_equal(independent, independent_snapshot), unchanged_msg
    assert list_ordered_equal(dependent, dependent_snapshot), unchanged_msg

    assert best.column_id == 0, 'Identifies correct column to split on'
    assert list_unordered_equal(best.split_map, [[1], [2]]), 'Correctly identifies catagories'
    assert list_unordered_equal(best.surrogates, []), 'No surrogates should be generated'
    assert best.p < 0.015
예제 #11
0
 def test_single_path(self):
     """Rules for ``self.last_node`` should be a single entry describing node 3."""
     path_rules = [
         {'data': [3], 'variable': 'c'},
         {'data': [2], 'variable': 'a'},
     ]
     expected_rules = [{'node': 3, 'rules': path_rules}]
     found = self.tree.classification_rules(self.last_node)
     assert list_unordered_equal(found, expected_rules), "Couldn't find path to determine rules from specified node"
예제 #12
0
 def test_all_paths(self):
     """With no node given, rules are expected for both node 3 and node 1."""
     node_three = {
         'node': 3,
         'rules': [
             {'data': [3], 'variable': 'c'},
             {'data': [2], 'variable': 'a'},
         ],
     }
     node_one = {'node': 1, 'rules': [{'data': [1], 'variable': 'a'}]}
     expected_rules = [node_three, node_one]
     found = self.tree.classification_rules()
     assert list_unordered_equal(found, expected_rules), "Couldn't find path to determine rules from all terminal nodes"
예제 #13
0
 def test_all_paths(self):
     """With no node given, rules are expected for both node 3 and node 1."""
     node_three = {
         'node': 3,
         'rules': [
             {'data': [3], 'variable': 'c'},
             {'data': [2], 'variable': 'a'},
         ],
     }
     node_one = {'node': 1, 'rules': [{'data': [1], 'variable': 'a'}]}
     expected_rules = [node_three, node_one]
     found = self.tree.classification_rules()
     assert list_unordered_equal(found, expected_rules), "Couldn't find path to determine rules from all terminal nodes"