コード例 #1
0
def test_zero_subbed_weighted_ndarry():
    """
    Check the weighted best split when some independent/dependent category
    pairing never occurs (here gender 2 only ever appears with income 0).

    In this case the zero count for that pairing is expected to be replaced
    by a very small float rather than used as-is.
    """
    gender = np.array(
        [0, 0, 1, 1, 0, 0, 1, 1, 0, 0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 1])
    income = np.array(
        [0, 0, 1, 1, 2, 0, 1, 1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    # First half of the rows is weighted 0.9, second half 1.9.
    half = len(gender) // 2
    weighting = np.array([0.9] * half + [1.9] * half)

    # Single-column (n_rows, 1) feature matrix holding the gender values.
    features = np.vstack([gender]).T

    tree = CHAID.Tree(features, income, alpha_merge=0.9, weights=weighting)
    best = tree.generate_best_split(
        tree.vectorised_array, tree.observed, weighting)

    assert round(best.chi, 4) == 14.5103
    assert round(best.p, 4) == 0.0007
コード例 #2
0
def test_best_split_unique_values():
    """
    Check a perfectly separable dataset where no category merges are needed.

    Also verifies that building the tree leaves the caller's arrays unchanged.
    """
    side_effect_msg = (
        'Calling chaid should have no side affects for original numpy arrays'
    )

    dep = np.array([1] * 5 + [2] * 5)
    dep_before = dep.copy()
    # Rows 0-4 are [1, 2, 3] and rows 5-9 are [2, 2, 3]: column 0 separates dep.
    indep = np.array([1, 2, 3] * 5 + [2, 2, 3] * 5).reshape(10, 3)
    indep_before = indep.copy()

    tree = CHAID.Tree(indep, dep)
    best = tree.generate_best_split(tree.vectorised_array, tree.observed)

    assert list_ordered_equal(indep, indep_before), side_effect_msg
    assert list_ordered_equal(dep, dep_before), side_effect_msg
    assert best.column_id == 0, 'Identifies correct column to split on'
    assert list_unordered_equal(
        best.split_map, [[1], [2]]
    ), 'Correctly identifies catagories'
    assert list_unordered_equal(
        best.surrogates, []
    ), 'No surrogates should be generated'
    assert best.p < 0.015
コード例 #3
0
ファイル: test_tree.py プロジェクト: tongliuliu/CHAID
 def setUp(self):
     """ Build a small hand-made tree store for the tree-inspection tests

     Layout: node 0 is the root; nodes 1 and 2 hang off it; node 3 hangs off
     node 2 and is kept as ``self.last_node``.
     """
     # One split object reused by nodes 1 and 3.
     dead_end = CHAID.Split(None, None, 0, 1, 0)
     self.tree = CHAID.Tree.from_numpy(np.array([[1]]), np.array([1]))

     root = CHAID.Node(
         node_id=0,
         split=CHAID.Split('a', [[1], [2]], 1, 0.2, 2)
     )
     left = CHAID.Node(node_id=1, split=dead_end, choices=[1], parent=0)
     right = CHAID.Node(
         node_id=2,
         split=CHAID.Split('c', [[3]], 1, 0.2, 2),
         choices=[2],
         parent=0
     )
     self.tree._tree_store = [root, left, right]

     self.last_node = CHAID.Node(
         node_id=3, split=dead_end, choices=[3], parent=2
     )
     self.tree._tree_store.append(self.last_node)
コード例 #4
0
 def setUp(self):
     """ Build the ordinal column used by the grouping tests """
     # 1.0, 2.0, four 3.0s, then 4.0, 5.0 and 10.0.
     self.col = CHAID.OrdinalColumn(
         np.array([1.0, 2.0] + [3.0] * 4 + [4.0, 5.0, 10.0])
     )
コード例 #5
0
 def setUp(self):
     """ Create an ordinal column and a deep copy of it for the copy tests """
     values = np.array([1, 2] + [3] * 4)
     self.orig = CHAID.OrdinalColumn(values)
     self.copy = self.orig.deep_copy()
コード例 #6
0
 def setUp(self):
     """ Create a nominal column and a deep copy of it for the copy tests """
     # String values are used so the column's data may hold references that a
     # shallow copy would share. NOTE(review): np.array of str yields a '<U'
     # dtype, not object; presumably NominalColumn converts it — confirm.
     self.orig = CHAID.NominalColumn(np.array(['5.0', '10.0']))
     self.copy = self.orig.deep_copy()
コード例 #7
0
 def setUp(self):
     """ Build the tree fixture shared by the tree generation tests """
     dependent = np.array([1] * 5 + [2] * 5)
     # 10 rows x 3 columns: five [1, 2, 3] rows followed by five [2, 2, 3] rows.
     independent = np.array([1, 2, 3] * 5 + [2, 2, 3] * 5).reshape(10, 3)
     self.tree = CHAID.Tree(independent, dependent)
コード例 #8
0
ファイル: test_split.py プロジェクト: codingblg/CHAID_phi
def test_invalid_split():
    """ Test properties when split invalid

    A split built with ``None`` for its first two arguments should report
    no invalid reason and no column.
    """
    split = CHAID.Split(None, None, None, 1, 0)
    # Compare to None by identity (PEP 8 / E711). ``==`` can be overridden by
    # the attribute's type (e.g. numpy objects) and give a misleading result.
    assert split.invalid_reason is None
    assert split.column is None
コード例 #9
0
ファイル: test_stats.py プロジェクト: codingblg/CHAID_phi
 def setUp(self):
     """ Setup test data for continuous data

     Builds two 120-element float fixtures (``random_arr``, with values
     between 0 and 1, and ``normal_arr``), a 120-element weight vector
     ``wt``, a list of four ``CHAID.NominalColumn`` objects in ``ndarr``,
     and one ``CHAID.Stats`` object for each float fixture.
     """
     # 120 values between 0 and 1 (presumably uniform random draws).
     self.random_arr = np.array(
        [0.23198952,  0.26550251,  0.96461057,  0.13733767,  0.76674088,
         0.60637166,  0.18822053,  0.78785506,  0.47786053,  0.44448984,
         0.88632344,  0.94060264,  0.52900520,  0.68301794,  0.00485769,
         0.09299505,  0.41767638,  0.22345506,  0.61899892,  0.53763263,
         0.41424529,  0.87527060,  0.10843391,  0.22902548,  0.52043049,
         0.82396842,  0.64215622,  0.42827082,  0.76920710,  0.27736853,
         0.95756523,  0.45140920,  0.12405161,  0.53774033,  0.72198885,
         0.37880053,  0.93554955,  0.44434796,  0.62834896,  0.02788777,
         0.30288893,  0.07198041,  0.59731867,  0.63485262,  0.79936557,
         0.41154027,  0.82900816,  0.49216809,  0.56649288,  0.26539558,
         0.12304309,  0.03233878,  0.64612524,  0.69844021,  0.30560065,
         0.05408900,  0.31020185,  0.93087523,  0.27952452,  0.57186781,
         0.36214135,  0.34114557,  0.82028983,  0.29795183,  0.21028335,
         0.41612748,  0.24781879,  0.19125266,  0.17214954,  0.44039645,
         0.84397111,  0.91060384,  0.70898285,  0.27049457,  0.15502956,
         0.47580771,  0.21507488,  0.68243381,  0.56233427,  0.22376202,
         0.76630117,  0.00162193,  0.15057895,  0.10145753,  0.69406461,
         0.81280760,  0.79726816,  0.42523241,  0.56025856,  0.10287649,
         0.53337746,  0.82185783,  0.38270064,  0.77411309,  0.01754383,
         0.84690273,  0.20057135,  0.37194360,  0.24657089,  0.91520048,
         0.65575302,  0.03220805,  0.71449568,  0.97194268,  0.94031990,
         0.61484448,  0.46961425,  0.38495625,  0.41865701,  0.81394666,
         0.57147433,  0.33414233,  0.13847757,  0.31316325,  0.04371212,
         0.36556674,  0.56316862,  0.66761528,  0.02491041,  0.12124478]
     )
     # 120 values; presumably samples from a normal distribution (per the
     # name) — TODO confirm how they were generated.
     self.normal_arr = np.array([
         215.74655491,  237.0905247 ,  193.72021408,  152.89363815,
         175.36670032,  232.59086085,  204.20219942,  248.99321897,
         267.95686148,  165.7204985 ,  177.38110221,  220.40618705,
         262.71893125,  240.00774431,  210.85572027,  255.06583994,
         232.85274614,  274.71932373,  186.83175676,  241.47832856,
         294.98781486,  190.82037054,  143.7991682 ,  170.32090888,
         207.20320791,  208.10226642,  187.09923858,  178.9242382 ,
         155.17266333,  140.69923988,  210.80029533,  193.85525698,
         232.69854217,  230.4408611 ,  149.34523942,  303.6243051 ,
         171.1562868 ,  185.24131426,  195.80616026,  224.38213062,
         261.77203837,  170.81218927,  216.37943211,  265.25650174,
         203.3098626 ,  229.84982086,  212.14777791,  265.25335911,
         296.11334434,  242.40424522,  270.30264815,   77.97401496,
         176.80382943,  156.35135782,  155.29031942,  262.11885208,
         161.33251252,  256.05120377,  158.32542953,  189.07183278,
         155.72524265,  244.68956731,  286.68689241,   94.08648606,
         253.80300049,  161.17371005,  116.94584491,  182.88557535,
         182.85752412,  253.42111371,  131.25146323,  264.86407965,
         197.3742505 ,  296.95506279,  221.01600673,  234.04694958,
         154.42957223,  176.94139196,  200.59554949,  170.4040058 ,
         229.39358115,  127.43357367,  249.09735255,  227.90731765,
         238.9667355 ,  163.83410357,  194.88998826,  134.49013182,
         154.54356067,  254.19699384,  143.93816979,  256.11031829,
         186.56096688,  178.40462838,  159.79032932,  187.7542398 ,
         267.18537402,  190.99969385,  130.30080584,  216.12902248,
         247.8707783 ,  246.49016072,  275.3636918 ,  165.69987612,
         181.16709806,  193.87951446,  156.03720504,  221.44032879,
         182.21405831,  119.22571297,  219.14946203,  140.358539  ,
         210.5826685 ,  256.57132523,  244.82587339,  153.26377344,
         198.44006972,  172.6057332 ,  140.26518016,  171.32162943]
     )
     # Weights: 1.0 for the first 60 rows, 1.2 for the last 60.
     self.wt = np.array(([1.0] * 60) + ([1.2] * 60))
     # 480 integers reshaped into 120 rows x 4 columns; each column is then
     # wrapped as a NominalColumn.
     ndarr = np.array(([2, 3] * 20) + ([2, 5] * 20) + ([3, 4] * 19) + [2, 3] + [1, 2, 5] * 80 + [1, 2, 3] * 40).reshape(120, 4)
     self.ndarr = [CHAID.NominalColumn(ndarr[:, i]) for i in range(ndarr.shape[1])]
     # NOTE(review): the meaning of the positional args (0.5, 10, .95) is set
     # by the CHAID.Stats signature — confirm there.
     self.stats_random_data = CHAID.Stats(0.5, 10, .95, self.random_arr)
     self.stats_normal_data = CHAID.Stats(0.5, 10, .95, self.normal_arr)