Пример #1
0
def test_fix_metadata_if_passed_in():
    """
    Check that metadata supplied to NominalColumn is kept, so that every
    stored code in ``column.arr`` looks up the label given for its value.
    """
    raw_values = [1.0, 2.0, 3.0]
    labels = ['Cat', 'Haz', 'Cheezburger']
    supplied_metadata = dict(zip(raw_values, labels))

    column = CHAID.NominalColumn(np.array(raw_values),
                                 metadata=supplied_metadata)

    decoded = [column.metadata[code] for code in column.arr]
    assert decoded == labels
Пример #2
0
def test_chaid_vector_with_ints_and_nan():
    """
    Check index substitution and metadata for a NominalColumn built from
    integer-valued data that contains NaN.

    The NaN makes numpy store the input as floats. The non-missing values
    are expected to map to indices 0 and 1. NaN is expected to map to
    index -1 with the label '<missing>'.
    """
    column = CHAID.NominalColumn(np.array([1, 2, np.nan]))

    expected_indices = np.array([0, 1, -1])
    expected_metadata = {0: 1, 1: 2, -1: '<missing>'}

    assert np.array_equal(column.arr, expected_indices), \
        'The indices are correctly substituted'
    assert column.metadata == expected_metadata, \
        'The metadata is formed correctly'
Пример #3
0
def test_all_combinations():
    """
    Check the exact sequence of groupings yielded by
    ``NominalColumn.all_combinations`` for four distinct input values.
    """
    column = CHAID.NominalColumn(np.array([1.0, 2.0, 3.0, 4.0]))

    expected = [
        [[0.0], [1.0, 2.0, 3.0]],
        [[0.0, 1.0], [2.0, 3.0]],
        [[1.0], [0.0, 2.0, 3.0]],
        [[0.0], [1.0], [2.0, 3.0]],
        [[0.0, 1.0, 2.0], [3.0]],
        [[1.0, 2.0], [0.0, 3.0]],
        [[0.0], [1.0, 2.0], [3.0]],
        [[0.0, 2.0], [1.0, 3.0]],
        [[2.0], [0.0, 1.0, 3.0]],
        [[0.0], [2.0], [1.0, 3.0]],
        [[0.0, 1.0], [2.0], [3.0]],
        [[1.0], [0.0, 2.0], [3.0]],
        [[1.0], [2.0], [0.0, 3.0]],
        [[0.0], [1.0], [2.0], [3.0]],
    ]
    assert list(column.all_combinations()) == expected
Пример #4
0
def test_chaid_vector_converts_strings():
    """
    Check that a NominalColumn built from strings replaces each string with
    an integer index and records the original string in its metadata.
    """
    source = np.array(['2', '4'])
    column = CHAID.NominalColumn(source)

    indices_match = np.array_equal(column.arr, np.array([0, 1]))
    assert indices_match, 'The indices are correctly substituted'

    expected_metadata = {0: '2', 1: '4'}
    assert column.metadata == expected_metadata, \
        'The metadata is formed correctly'
Пример #5
0
def test_chaid_vector_with_dtype_object():
    """
    Check that a NominalColumn built from an object-dtype array of ints
    substitutes integer indices and keeps the original ints as metadata.
    """
    column = CHAID.NominalColumn(np.array([1, 2], dtype="object"))

    indices_match = np.array_equal(column.arr, np.array([0, 1]))
    assert indices_match, 'The indices are correctly substituted'
    assert column.metadata == {0: 1, 1: 2}, 'The metadata is formed correctly'
Пример #6
0
def test_chaid_vector_with_floats():
    """
    Check that a NominalColumn built from floats substitutes integer
    indices and keeps the original float values as metadata.
    """
    column = CHAID.NominalColumn(np.array([1.0, 2.0]))

    indices_match = np.array_equal(column.arr, np.array([0, 1]))
    assert indices_match, 'The indices are correctly substituted'
    assert column.metadata == {0: 1.0, 1: 2.0}, \
        'The metadata is formed correctly'
Пример #7
0
def test_column_stores_weights():
    """
    Check that Nominal, Ordinal and Continuous columns all expose the
    weights array passed to their constructor through ``.weights``.
    """
    values = np.array([1.0, 2.0, 3.0])
    weights = np.array([2.0, 1.0, 0.25])

    # Build every column first, then compare, as the original test did.
    built = [
        CHAID.NominalColumn(values, weights=weights),
        CHAID.OrdinalColumn(values, weights=weights),
        CHAID.ContinuousColumn(values, weights=weights),
    ]
    for column in built:
        assert (column.weights == weights).all()
Пример #8
0
    def test_comparison_of_different_object_types(self):
        """
        Regression test for a bug where floats reached NominalColumn (via
        `self.observed = NominalColumn(arr)`) as `dtype=object`. On Python
        3.x only, this raised
        `TypeError: unorderable types: int() > str()`.

        Builds a column from mixed ints, a string and NaNs. Checks that each
        stored index maps back to its original value through the metadata,
        with NaN mapping to '<missing>'.
        """
        raw = [100, 'c', 13, 15, np.nan, np.nan]
        vector = CHAID.NominalColumn(np.array(raw, dtype=object))

        # NaN compares unequal to itself, which is how the missing entries
        # are told apart from real values here.
        expected = [item if item == item else '<missing>' for item in raw]
        decoded = [vector.metadata[code] for code in vector.arr]
        assert decoded == expected
Пример #9
0
def test_new_columns_constructor():
    """
    Check that CHAID.Tree accepts pre-built CHAID Column objects. Also check
    that the split of the first node in ``tree_store`` reports its groupings
    using the labels from the supplied column metadata.
    """
    labels = {0: '0-5', 1: '6-10', 2: '11-15'}
    orientation = np.array([0, 0, 1, 1, 0, 0, 1, 1, 0, 0,
                            1, 2, 2, 2, 2, 2, 2, 2, 2, 1])
    age = np.array([0, 1, 1, 0, 2, 2, 2, 2, 1, 1,
                    1, 0, 0, 0, 0, 0, 0, 0, 0, 0])
    income = np.array([0, 0, 1, 1, 2, 0, 1, 1, 1, 0,
                       1, 0, 0, 0, 0, 0, 0, 0, 0, 0])

    independents = [
        CHAID.OrdinalColumn(orientation, name="orientation", metadata=labels),
        CHAID.OrdinalColumn(age, name="age", metadata=labels),
    ]
    dependent = CHAID.NominalColumn(income)
    config = {'min_child_node_size': 1}
    tree = CHAID.Tree(independents, dependent, config)

    first_split = tree.tree_store[0].split
    assert first_split.groupings == "[['0-5'], ['6-10', '11-15']]"
Пример #10
0
 def setUp(self):
     """
     Fixture for the copy tests: a NominalColumn built from string values
     (``self.orig``) and the result of its ``deep_copy()`` (``self.copy``).
     """
     # NOTE(review): the original intent was an object-dtype array that may
     # store references. However, np.array(['5.0', '10.0']) gets a fixed-width
     # unicode dtype ('<U4'), not object. Confirm whether NominalColumn
     # converts it, or whether dtype=object was meant.
     string_values = np.array(['5.0', '10.0'])
     original = CHAID.NominalColumn(string_values)
     self.orig = original
     self.copy = original.deep_copy()
Пример #11
0
 def setUp(self):
     """
     Fixture for the continuous-data tests.

     Sets:
       * ``random_arr`` / ``normal_arr``: two fixed 120-element float arrays.
       * ``wt``: 120 weights, the first 60 equal to 1.0 and the last 60 to 1.2.
       * ``ndarr``: a list of four NominalColumns, one per column of a
         120x4 integer array.
       * ``stats_random_data`` / ``stats_normal_data``: a CHAID.Stats built
         with the same leading arguments for each float array.
     """
     random_values = np.array([
         0.23198952,  0.26550251,  0.96461057,  0.13733767,  0.76674088,
         0.60637166,  0.18822053,  0.78785506,  0.47786053,  0.44448984,
         0.88632344,  0.94060264,  0.52900520,  0.68301794,  0.00485769,
         0.09299505,  0.41767638,  0.22345506,  0.61899892,  0.53763263,
         0.41424529,  0.87527060,  0.10843391,  0.22902548,  0.52043049,
         0.82396842,  0.64215622,  0.42827082,  0.76920710,  0.27736853,
         0.95756523,  0.45140920,  0.12405161,  0.53774033,  0.72198885,
         0.37880053,  0.93554955,  0.44434796,  0.62834896,  0.02788777,
         0.30288893,  0.07198041,  0.59731867,  0.63485262,  0.79936557,
         0.41154027,  0.82900816,  0.49216809,  0.56649288,  0.26539558,
         0.12304309,  0.03233878,  0.64612524,  0.69844021,  0.30560065,
         0.05408900,  0.31020185,  0.93087523,  0.27952452,  0.57186781,
         0.36214135,  0.34114557,  0.82028983,  0.29795183,  0.21028335,
         0.41612748,  0.24781879,  0.19125266,  0.17214954,  0.44039645,
         0.84397111,  0.91060384,  0.70898285,  0.27049457,  0.15502956,
         0.47580771,  0.21507488,  0.68243381,  0.56233427,  0.22376202,
         0.76630117,  0.00162193,  0.15057895,  0.10145753,  0.69406461,
         0.81280760,  0.79726816,  0.42523241,  0.56025856,  0.10287649,
         0.53337746,  0.82185783,  0.38270064,  0.77411309,  0.01754383,
         0.84690273,  0.20057135,  0.37194360,  0.24657089,  0.91520048,
         0.65575302,  0.03220805,  0.71449568,  0.97194268,  0.94031990,
         0.61484448,  0.46961425,  0.38495625,  0.41865701,  0.81394666,
         0.57147433,  0.33414233,  0.13847757,  0.31316325,  0.04371212,
         0.36556674,  0.56316862,  0.66761528,  0.02491041,  0.12124478,
     ])
     normal_values = np.array([
         215.74655491,  237.0905247 ,  193.72021408,  152.89363815,
         175.36670032,  232.59086085,  204.20219942,  248.99321897,
         267.95686148,  165.7204985 ,  177.38110221,  220.40618705,
         262.71893125,  240.00774431,  210.85572027,  255.06583994,
         232.85274614,  274.71932373,  186.83175676,  241.47832856,
         294.98781486,  190.82037054,  143.7991682 ,  170.32090888,
         207.20320791,  208.10226642,  187.09923858,  178.9242382 ,
         155.17266333,  140.69923988,  210.80029533,  193.85525698,
         232.69854217,  230.4408611 ,  149.34523942,  303.6243051 ,
         171.1562868 ,  185.24131426,  195.80616026,  224.38213062,
         261.77203837,  170.81218927,  216.37943211,  265.25650174,
         203.3098626 ,  229.84982086,  212.14777791,  265.25335911,
         296.11334434,  242.40424522,  270.30264815,   77.97401496,
         176.80382943,  156.35135782,  155.29031942,  262.11885208,
         161.33251252,  256.05120377,  158.32542953,  189.07183278,
         155.72524265,  244.68956731,  286.68689241,   94.08648606,
         253.80300049,  161.17371005,  116.94584491,  182.88557535,
         182.85752412,  253.42111371,  131.25146323,  264.86407965,
         197.3742505 ,  296.95506279,  221.01600673,  234.04694958,
         154.42957223,  176.94139196,  200.59554949,  170.4040058 ,
         229.39358115,  127.43357367,  249.09735255,  227.90731765,
         238.9667355 ,  163.83410357,  194.88998826,  134.49013182,
         154.54356067,  254.19699384,  143.93816979,  256.11031829,
         186.56096688,  178.40462838,  159.79032932,  187.7542398 ,
         267.18537402,  190.99969385,  130.30080584,  216.12902248,
         247.8707783 ,  246.49016072,  275.3636918 ,  165.69987612,
         181.16709806,  193.87951446,  156.03720504,  221.44032879,
         182.21405831,  119.22571297,  219.14946203,  140.358539  ,
         210.5826685 ,  256.57132523,  244.82587339,  153.26377344,
         198.44006972,  172.6057332 ,  140.26518016,  171.32162943,
     ])
     self.random_arr = random_values
     self.normal_arr = normal_values

     # 60 weights of 1.0 followed by 60 weights of 1.2.
     self.wt = np.repeat([1.0, 1.2], 60)

     # 480 integers: 120 + 240 + 120, reshaped row-major into 120 rows of 4.
     flat_codes = (
         [2, 3] * 20 + [2, 5] * 20 + [3, 4] * 19 + [2, 3]
         + [1, 2, 5] * 80
         + [1, 2, 3] * 40
     )
     grid = np.array(flat_codes).reshape(120, 4)
     self.ndarr = [CHAID.NominalColumn(grid[:, col])
                   for col in range(grid.shape[1])]

     stats_args = (0.5, 10, .95)
     self.stats_random_data = CHAID.Stats(*stats_args, self.random_arr)
     self.stats_normal_data = CHAID.Stats(*stats_args, self.normal_arr)