Ejemplo n.º 1
0
class TestGaussian(unittest.TestCase):
    """Tests for a ``GaussianPolicy`` built with an explicit action dimension.

    In this version, calling the policy returns an action tensor directly and
    ``reinforce`` is called with a tensor argument after each forward pass.

    NOTE(review): another ``class TestGaussian`` is defined later in this
    file. That later definition rebinds the same module-level name, so test
    discovery will not run the methods of this class. Rename one of the two
    classes if both suites are meant to run.
    """

    def setUp(self):
        # Fixed seed so weight initialisation and the random states are
        # repeatable across runs.
        torch.manual_seed(2)
        # A single linear layer with ACTION_DIM * 2 outputs per state;
        # presumably GaussianPolicy splits these into mean and scale
        # parameters -- TODO confirm against GaussianPolicy.
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTION_DIM * 2))
        optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01)
        self.policy = GaussianPolicy(self.model, optimizer, ACTION_DIM)

    def test_output_shape(self):
        """The policy returns one ACTION_DIM-wide action per batched state."""
        for batch_size in (1, 5):
            state = State(torch.randn(batch_size, STATE_DIM))
            action = self.policy(state)
            self.assertEqual(action.shape, (batch_size, ACTION_DIM))

    def test_reinforce_one(self):
        """Smoke test: a single forward pass followed by reinforce must not raise."""
        state = State(torch.randn(1, STATE_DIM))
        self.policy(state)
        # One scalar signal for the single state in the batch; presumably a
        # return/advantage value -- TODO confirm reinforce's contract.
        self.policy.reinforce(torch.tensor([1.]))

    def test_converge(self):
        """Repeated reinforce updates bring the actions close to a fixed target."""
        state = State(torch.randn(1, STATE_DIM))
        target = torch.tensor([1., 2., -1.])

        for _ in range(1000):
            action = self.policy(state)
            loss = torch.abs(target - action).mean()
            # The negated loss is passed as the reinforcement signal, so a
            # smaller distance to the target presumably acts as a larger
            # reward -- TODO confirm reinforce's sign convention.
            self.policy.reinforce(-loss)

        # Compare a Python float so a failure reports the actual loss value
        # instead of an opaque "tensor(False) is not true".
        self.assertLess(loss.item(), 1)
class TestGaussian(unittest.TestCase):
    """Tests for a ``GaussianPolicy`` built from a ``Box`` action space.

    In this version, calling the policy returns a distribution-like object
    (the tests use its ``sample``, ``log_prob``, ``mean`` and ``entropy``),
    and ``reinforce`` is given a loss tensor to minimise.
    """

    def setUp(self):
        # Fixed seed: the hard-coded expectations in test_eval presumably
        # depend on this exact RNG state.
        torch.manual_seed(2)
        # Three-dimensional action space; the arrays are presumably the
        # low/high bounds of the Box -- TODO confirm against Box.
        self.space = Box(np.array([-1, -1, -1]), np.array([1, 1, 1]))
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTION_DIM * 2))
        optimizer = torch.optim.RMSprop(self.model.parameters(), lr=0.01)
        self.policy = GaussianPolicy(self.model,
                                     optimizer,
                                     self.space,
                                     checkpointer=DummyCheckpointer())

    def test_output_shape(self):
        """Samples have one ACTION_DIM-wide row per batched state."""
        for batch_size in (1, 5):
            state = State(torch.randn(batch_size, STATE_DIM))
            action = self.policy(state).sample()
            self.assertEqual(action.shape, (batch_size, ACTION_DIM))

    def test_reinforce_one(self):
        """One update on -log_prob(a) makes the sampled action a more likely."""
        state = State(torch.randn(1, STATE_DIM))
        dist = self.policy(state)
        action = dist.sample()
        log_prob1 = dist.log_prob(action)
        # Minimising the negative log-likelihood of the sampled action should
        # increase that action's log-probability under the updated policy.
        loss = -log_prob1.mean()
        self.policy.reinforce(loss)

        dist = self.policy(state)
        log_prob2 = dist.log_prob(action)

        self.assertGreater(log_prob2.item(), log_prob1.item())

    def test_converge(self):
        """Score-function updates drive sampled actions toward a fixed target."""
        state = State(torch.randn(1, STATE_DIM))
        target = torch.tensor([1., 2., -1.])

        for _ in range(1000):
            dist = self.policy(state)
            action = dist.sample()
            log_prob = dist.log_prob(action)
            error = ((target - action)**2).mean()
            # Weight the log-probability by the action's squared error, so
            # minimising the loss lowers the likelihood of high-error actions.
            # Assumes dist.sample() does not propagate gradients (i.e. it is
            # not a reparameterised sample) -- TODO confirm.
            loss = (error * log_prob).mean()
            self.policy.reinforce(loss)

        # `error` is from the final single sample only; compare as a float so
        # a failure reports the actual value.
        self.assertLess(error.item(), 1)

    def test_eval(self):
        """no_grad yields the distribution; eval's output matches its mean."""
        state = State(torch.randn(1, STATE_DIM))
        # Reference values for the seeded model in setUp; they will change if
        # the seed, model architecture or RNG consumption order changes.
        expected_mean = torch.tensor([[-0.233, 0.459, -0.058]])

        dist = self.policy.no_grad(state)
        tt.assert_almost_equal(dist.mean, expected_mean, decimal=3)
        tt.assert_almost_equal(dist.entropy(),
                               torch.tensor([4.251]),
                               decimal=3)

        best = self.policy.eval(state)
        tt.assert_almost_equal(best, expected_mean, decimal=3)