コード例 #1
0
ファイル: test_simplex.py プロジェクト: andrewhead/fabexample
class UpdatePointsTest(unittest.TestCase):
    def setUp(self):
        self.simplex = Simplex()
        self.rank_func = lambda X: np.array([x[0] for x in X])

    def test_expand_points(self):
        vertices = self.simplex.update_points(np.array([[1.0, 0], [2.0, 0], [3.1, 0]]), self.rank_func)
        self.assertTrue(np.all(vertices - np.array([[1, 0], [2, 0], [-0.1, 0]]) < 0.0001))

    def test_reflect_points(self):
        vertices = self.simplex.update_points(np.array([[1.0, 0], [2.0, 0], [2.9, 0]]), self.rank_func)
        self.assertTrue(np.all(vertices - np.array([[1.0, 0], [2.0, 0], [float(31) / 30, 0]]) < 0.0001))

    def test_contract_points(self):
        vertices = self.simplex.update_points(
            np.array([[1.0, 0], [2.0, 0], [6.0, 0]]),
            lambda X: np.array([2, 3, 4, 5, 6, 1]) if len(X) == 6 else np.array([1, 2, 3]),
        )
        self.assertTrue(np.all(vertices - np.array([[1, 0], [2, 0], [4.5, 0]]) < 0.0001))

    def test_reduce_all(self):
        vertices = self.simplex.update_points(
            np.array([[1.0, 0], [2.0, 0], [3.0, 0]]),
            lambda X: np.array([1, 2, 3, 4, 5, 6]) if len(X) == 6 else np.array([1, 2, 3]),
        )
        self.assertTrue(np.all(vertices - np.array([[1, 0], [1.5, 0], [2, 0]]) < 0.0001))