Exemplo n.º 1
0
    def test_solver_uses_policy_and_data(self):
        """Check that ``lspi.learn`` hands the data and policy to the solver.

        The stub is constructed with the expected data and policy, and the
        very same objects (read back from the stub) are passed into
        ``lspi.learn`` for a single iteration.
        """
        # NOTE(review): there is no explicit assert here; presumably
        # SolverParamStub verifies the arguments it receives — confirm.
        stub = SolverParamStub([10], Policy(FakeBasis(1)))

        lspi.learn(stub.data, stub.policy, stub, max_iterations=1)
Exemplo n.º 2
0
class TestFakeBasis(TestCase):
    """Unit tests for the ``FakeBasis`` test double."""

    def setUp(self):
        # Each test gets its own basis configured for six actions.
        self.basis = FakeBasis(6)

    def test_num_actions_property(self):
        self.assertEqual(self.basis.num_actions, 6)

    def test_num_actions_setter(self):
        updated_count = 10
        self.basis.num_actions = updated_count

        self.assertEqual(self.basis.num_actions, updated_count)

    def test_num_actions_setter_invalid_value(self):
        # Zero actions is rejected by the setter.
        self.assertRaises(ValueError, setattr, self.basis, 'num_actions', 0)

    def test_size(self):
        self.assertEqual(self.basis.size(), 1)

    def test_evaluate(self):
        features = self.basis.evaluate(None, 0)
        np.testing.assert_array_almost_equal(features, np.array([1.]))

    def test_evaluate_negative_action_index(self):
        self.assertRaises(IndexError, self.basis.evaluate, None, -1)

    def test_evaluate_out_of_bounds_action_index(self):
        # Valid indices are 0..5 for a six-action basis, so 6 is out of range.
        self.assertRaises(IndexError, self.basis.evaluate, None, 6)
Exemplo n.º 3
0
    def test_epsilon_stopping_condition(self):
        """Check that learning halts once the distance falls below epsilon.

        An ``epsilon`` of 0 must be rejected with ``ValueError``. With a
        stub configured at 1e-21 and a threshold of 1e-20, the solver should
        be invoked exactly once even though up to 1000 iterations are allowed.
        """
        self.assertRaises(ValueError, lspi.learn, None, None, None, epsilon=0)

        stub_distance = 10**-21
        threshold = 10**-20
        solver = EpsilonSolverStub(stub_distance)

        lspi.learn(None, Policy(FakeBasis(1)), solver,
                   epsilon=threshold, max_iterations=1000)

        self.assertEqual(solver.num_calls, 1)
Exemplo n.º 4
0
    def test_max_iterations_stopping_condition(self):
        """Check that learning halts after ``max_iterations`` solver calls.

        A ``max_iterations`` of 0 must be rejected with ``ValueError``. With
        a practically unreachable epsilon (1e-200), the solver should be
        called exactly as many times as the iteration cap allows.
        """
        self.assertRaises(ValueError, lspi.learn, None, None, None,
                          max_iterations=0)

        iteration_cap = 10
        solver = MaxIterationsSolverStub()

        lspi.learn(None, Policy(FakeBasis(1)), solver,
                   epsilon=10**-200, max_iterations=iteration_cap)

        self.assertEqual(solver.num_calls, iteration_cap)
Exemplo n.º 5
0
    def test_returns_policy_with_new_weights(self):
        """Test that learn returns a new policy carrying the solver's weights.

        The returned policy must be a different object from the initial
        policy, its weight vector must not share memory with the initial
        policy's weight vector, and its weights must equal the weights the
        solver stub reports.
        """

        initial_policy = Policy(FakeBasis(1))

        weight_solver = WeightSolverStub(initial_policy.weights)

        new_policy = lspi.learn(None,
                                initial_policy,
                                weight_solver,
                                max_iterations=1)

        self.assertEqual(weight_solver.num_calls, 1)
        # Compare the two weight arrays. Passing the Policy object itself
        # makes numpy wrap it in a fresh object array, which can never share
        # memory, so the check would pass vacuously.
        self.assertFalse(
            np.may_share_memory(initial_policy.weights, new_policy.weights))
        # Identity check; assertNotEquals was a deprecated alias removed in
        # Python 3.12.
        self.assertIsNot(initial_policy, new_policy)
        np.testing.assert_array_almost_equal(new_policy.weights,
                                             weight_solver.weights)
 def create_policy(self, *args, **kwargs):
     """Build a ``Policy`` over a five-action ``FakeBasis``.

     Any extra positional or keyword arguments are forwarded unchanged to
     the ``Policy`` constructor.
     """
     basis = FakeBasis(5)
     return Policy(basis, *args, **kwargs)
Exemplo n.º 7
0
 def setUp(self):
     """Give each test a fresh ``FakeBasis(6)`` stored as ``self.basis``."""
     self.basis = FakeBasis(6)