예제 #1
0
class LogRegTests(unittest.TestCase):
    """Unit tests for the ``LogReg`` distance built on ``Identity`` statistics."""

    def setUp(self):
        self.stat_calc = Identity(degree=2, cross=False)
        # seed=1 is forwarded to LogReg (presumably to make its fit
        # deterministic -- confirm against LogReg's constructor).
        self.distancefunc = LogReg(self.stat_calc, seed=1)
        # Dedicated, seeded generator so the synthetic data is reproducible
        # and independent of numpy's global random state.
        self.rng = np.random.RandomState(1)

    def test_distance(self):
        """Check input-type validation and the two extreme distance values."""
        # Two 2-D clouds with std 0.5 centred at -10 and +10: their centres
        # are 40 standard deviations apart, so they are completely separable.
        d1 = 0.5 * self.rng.randn(100, 2) - 10
        d2 = 0.5 * self.rng.randn(100, 2) + 10

        # Convert before the type checks so that in each assertRaises call
        # below the float 3.4 is the only unexpected argument.
        d1 = d1.tolist()
        d2 = d2.tolist()

        # Checks whether wrong input type produces error message
        self.assertRaises(TypeError, self.distancefunc.distance, 3.4, d2)
        self.assertRaises(TypeError, self.distancefunc.distance, d1, 3.4)

        # completely separable datasets should have a distance of 1.0
        self.assertEqual(self.distancefunc.distance(d1, d2), 1.0)

        # equal data sets should have a distance of 0.0
        self.assertEqual(self.distancefunc.distance(d1, d1), 0.0)

    def test_dist_max(self):
        """The reported maximum distance is expected to be 1.0."""
        # assertEqual (rather than assertTrue(x == 1.0)) reports the actual
        # value when the check fails.
        self.assertEqual(self.distancefunc.dist_max(), 1.0)
예제 #2
0
class LogRegTests(unittest.TestCase):
    """Unit tests for the ``LogReg`` distance built on ``Identity`` statistics."""

    def setUp(self):
        self.stat_calc = Identity(degree=1, cross=0)
        self.distancefunc = LogReg(self.stat_calc)
        # Dedicated, seeded generator so the synthetic data is reproducible
        # across runs and does not depend on numpy's global random state.
        self.rng = np.random.RandomState(1)

    def test_distance(self):
        """Check input-type validation and the two extreme distance values."""
        # Two 2-D clouds with std 0.5 centred at -10 and +10: their centres
        # are 40 standard deviations apart, so they are completely separable.
        d1 = 0.5 * self.rng.randn(100, 2) - 10
        d2 = 0.5 * self.rng.randn(100, 2) + 10

        # Convert to lists *before* the type checks. If d1/d2 were still
        # ndarrays here and distance() rejects non-list inputs, an
        # assertRaises could pass because of the ndarray argument rather
        # than the 3.4 actually under test.
        d1 = list(d1)
        d2 = list(d2)

        # Checks whether wrong input type produces error message
        self.assertRaises(TypeError, self.distancefunc.distance, 3.4, d2)
        self.assertRaises(TypeError, self.distancefunc.distance, d1, 3.4)

        # completely separable datasets should have a distance of 1.0
        self.assertEqual(self.distancefunc.distance(d1, d2), 1.0)

        # equal data sets should have a distance of 0.0
        self.assertEqual(self.distancefunc.distance(d1, d1), 0.0)

    def test_dist_max(self):
        """The reported maximum distance is expected to be 1.0."""
        # assertEqual (rather than assertTrue(x == 1.0)) reports the actual
        # value when the check fails.
        self.assertEqual(self.distancefunc.dist_max(), 1.0)