Example #1
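# Assumed imports and constants (not shown in the original snippet): unittest,
# torch, and nn are standard; torch_testing, State, SoftmaxPolicy, and their
# module paths are guesses based on how they are used below and may differ
# between library versions.
import unittest

import torch
import torch_testing as tt
from torch import nn

from all.environments import State
from all.policies import SoftmaxPolicy

STATE_DIM = 2  # assumed size of the toy state vector
ACTIONS = 3    # actions 0-2 appear in the assertions below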
class TestSoftmax(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(2)
        self.model = nn.Sequential(
            nn.Linear(STATE_DIM, ACTIONS)
        )
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.policy = SoftmaxPolicy(self.model, optimizer)

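    # With the seed set in setUp, the first two sampled actions are 0 and 2.
    # The loss below weights the second action's log-probability with a huge
    # positive return, so after the reinforce update a fresh sample (from a
    # new random state) again returns action 2.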
    def test_run(self):
        state1 = State(torch.randn(1, STATE_DIM))
        dist1 = self.policy(state1)
        action1 = dist1.sample()
        log_prob1 = dist1.log_prob(action1)
        self.assertEqual(action1.item(), 0)

        state2 = State(torch.randn(1, STATE_DIM))
        dist2 = self.policy(state2)
        action2 = dist2.sample()
        log_prob2 = dist2.log_prob(action2)
        self.assertEqual(action2.item(), 2)

        loss = -(torch.tensor([-1, 1000000]) * torch.cat((log_prob1, log_prob2))).mean()
        self.policy.reinforce(loss)

        state3 = State(torch.randn(1, STATE_DIM))
        dist3 = self.policy(state3)
        action3 = dist3.sample()
        self.assertEqual(action3.item(), 2)

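    # A batched State yields one sampled action per row.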
    def test_multi_action(self):
        states = State(torch.randn(3, STATE_DIM))
        actions = self.policy(states).sample()
        tt.assert_equal(actions, torch.tensor([2, 2, 0]))

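    # The second tensor passed to State is presumably a terminal/done mask;
    # batched sampling, log_prob, and reinforce still work in its presence.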
    def test_list(self):
        torch.manual_seed(1)
        states = State(torch.randn(3, STATE_DIM), torch.tensor([1, 0, 1]))
        dist = self.policy(states)
        actions = dist.sample()
        log_probs = dist.log_prob(actions)
        tt.assert_equal(actions, torch.tensor([1, 2, 1]))
        loss = -(torch.tensor([[1, 2, 3]]) * log_probs).mean()
        self.policy.reinforce(loss)

    def test_reinforce(self):
        def loss(log_probs):
            return -log_probs.mean()

        states = State(torch.randn(3, STATE_DIM), torch.tensor([1, 1, 1]))
        actions = self.policy.eval(states).sample()

        # notice that the log-probabilities of the chosen actions increase with each successive reinforce
        log_probs = self.policy(states).log_prob(actions)
        tt.assert_almost_equal(log_probs, torch.tensor([-0.84, -0.62, -0.757]), decimal=3)
        self.policy.reinforce(loss(log_probs))
        log_probs = self.policy(states).log_prob(actions)
        tt.assert_almost_equal(log_probs, torch.tensor([-0.811, -0.561, -0.701]), decimal=3)
        self.policy.reinforce(loss(log_probs))
        log_probs = self.policy(states).log_prob(actions)
        tt.assert_almost_equal(log_probs, torch.tensor([-0.785, -0.51, -0.651]), decimal=3)
Example #2
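# Example #2 assumes the same imports and constants as Example #1. It exercises
# a SoftmaxPolicy variant whose constructor takes ACTIONS, whose call returns
# actions directly, and whose reinforce takes a tensor of returns rather than a
# loss tensor.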
class TestSoftmax(unittest.TestCase):
    def setUp(self):
        torch.manual_seed(2)
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS))
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.policy = SoftmaxPolicy(self.model, optimizer, ACTIONS)

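    # In this API variant the policy call returns an action directly and
    # appears to buffer it internally; reinforce is then given a tensor of
    # returns, one per action taken since the last update. The large return
    # for the second action pushes the policy toward action 2.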
    def test_run(self):
        state = State(torch.randn(1, STATE_DIM))
        action = self.policy(state)
        self.assertEqual(action.item(), 0)
        state = State(torch.randn(1, STATE_DIM))
        action = self.policy(state)
        self.assertEqual(action.item(), 2)
        self.policy.reinforce(torch.tensor([-1, 1000000]).float())
        action = self.policy(state)
        self.assertEqual(action.item(), 2)

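    # Batched states produce a tensor of actions; reinforce is then called
    # with per-action returns.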
    def test_multi_action(self):
        states = State(torch.randn(3, STATE_DIM))
        actions = self.policy(states)
        tt.assert_equal(actions, torch.tensor([2, 2, 0]))
        self.policy.reinforce(torch.tensor([[1, 2, 3]]).float())

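    # Three batched calls buffer six actions in total; reinforce consumes them
    # in order (4, then 2), so a third reinforce, with no buffered actions
    # left, is expected to raise.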
    def test_multi_batch_reinforce(self):
        self.policy(State(torch.randn(2, STATE_DIM)))
        self.policy(State(torch.randn(2, STATE_DIM)))
        self.policy(State(torch.randn(2, STATE_DIM)))
        self.policy.reinforce(torch.tensor([1, 2, 3, 4]).float())
        self.policy.reinforce(torch.tensor([1, 2]).float())
        with self.assertRaises(Exception):
            self.policy.reinforce(torch.tensor([1, 2]).float())

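    # Same terminal-mask scenario as Example #1's test_list, exercised against
    # this API variant.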
    def test_list(self):
        torch.manual_seed(1)
        states = State(torch.randn(3, STATE_DIM), torch.tensor([1, 0, 1]))
        actions = self.policy(states)
        tt.assert_equal(actions, torch.tensor([1, 2, 1]))
        self.policy.reinforce(torch.tensor([[1, 2, 3]]).float())

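    # Passing previously chosen actions back in returns their
    # log-probabilities instead of sampling new actions.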
    def test_action_prob(self):
        torch.manual_seed(1)
        states = State(torch.randn(3, STATE_DIM), torch.tensor([1, 0, 1]))
        with torch.no_grad():
            actions = self.policy(states)
        log_probs = self.policy(states, action=actions)
        tt.assert_almost_equal(log_probs,
                               torch.tensor([-1.59, -1.099, -1.528]),
                               decimal=3)

    def test_custom_loss(self):
        def loss(log_probs):
            return -log_probs.mean()

        states = State(torch.randn(3, STATE_DIM), torch.tensor([1, 1, 1]))
        actions = self.policy.eval(states)

        # notice that the log-probabilities of the chosen actions increase with each successive reinforce
        log_probs = self.policy(states, actions)
        tt.assert_almost_equal(log_probs,
                               torch.tensor([-0.84, -0.62, -0.757]),
                               decimal=3)
        self.policy.reinforce(loss)
        log_probs = self.policy(states, actions)
        tt.assert_almost_equal(log_probs,
                               torch.tensor([-0.811, -0.561, -0.701]),
                               decimal=3)
        self.policy.reinforce(loss)
        log_probs = self.policy(states, actions)
        tt.assert_almost_equal(log_probs,
                               torch.tensor([-0.785, -0.51, -0.651]),
                               decimal=3)