class TestSoftmax(unittest.TestCase):
    """Tests for SoftmaxPolicy using the distribution-returning API.

    NOTE(review): a second class named ``TestSoftmax`` appears later in this
    file and shadows this one at import time — confirm which version is
    intended to run and rename or delete the other.

    All expected action/log-prob values below are pinned to the fixed RNG
    seeds, so the exact order of ``torch.randn``/``sample`` calls must not
    change.
    """

    def setUp(self):
        # Fixed seed makes every sampled action deterministic for the asserts.
        torch.manual_seed(2)
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS))
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.policy = SoftmaxPolicy(self.model, optimizer)

    def test_run(self):
        """Sample twice, reinforce toward the second action, sample again."""
        first_state = State(torch.randn(1, STATE_DIM))
        first_dist = self.policy(first_state)
        first_action = first_dist.sample()
        first_log_prob = first_dist.log_prob(first_action)
        self.assertEqual(first_action.item(), 0)

        second_state = State(torch.randn(1, STATE_DIM))
        second_dist = self.policy(second_state)
        second_action = second_dist.sample()
        second_log_prob = second_dist.log_prob(second_action)
        self.assertEqual(second_action.item(), 2)

        # Huge positive weight on the second log-prob pushes the policy
        # strongly toward repeating action 2 after one SGD step.
        weights = torch.tensor([-1, 1000000])
        loss = -(weights * torch.cat((first_log_prob, second_log_prob))).mean()
        self.policy.reinforce(loss)

        third_state = State(torch.randn(1, STATE_DIM))
        third_dist = self.policy(third_state)
        third_action = third_dist.sample()
        self.assertEqual(third_action.item(), 2)

    def test_multi_action(self):
        """A batch of states yields one sampled action per row."""
        batch = State(torch.randn(3, STATE_DIM))
        sampled = self.policy(batch).sample()
        tt.assert_equal(sampled, torch.tensor([2, 2, 0]))

    def test_list(self):
        """Batched states with a mask tensor still sample and reinforce."""
        torch.manual_seed(1)
        batch = State(torch.randn(3, STATE_DIM), torch.tensor([1, 0, 1]))
        dist = self.policy(batch)
        sampled = dist.sample()
        log_probs = dist.log_prob(sampled)
        tt.assert_equal(sampled, torch.tensor([1, 2, 1]))
        loss = -(torch.tensor([[1, 2, 3]]) * log_probs).mean()
        self.policy.reinforce(loss)

    def test_reinforce(self):
        """Log-probs of the chosen actions rise with each reinforce step."""

        def neg_mean_loss(log_probs):
            return -log_probs.mean()

        batch = State(torch.randn(3, STATE_DIM), torch.tensor([1, 1, 1]))
        chosen = self.policy.eval(batch).sample()

        # notice the values increase with each successive reinforce
        log_probs = self.policy(batch).log_prob(chosen)
        tt.assert_almost_equal(
            log_probs, torch.tensor([-0.84, -0.62, -0.757]), decimal=3
        )
        self.policy.reinforce(neg_mean_loss(log_probs))

        log_probs = self.policy(batch).log_prob(chosen)
        tt.assert_almost_equal(
            log_probs, torch.tensor([-0.811, -0.561, -0.701]), decimal=3
        )
        self.policy.reinforce(neg_mean_loss(log_probs))

        log_probs = self.policy(batch).log_prob(chosen)
        tt.assert_almost_equal(
            log_probs, torch.tensor([-0.785, -0.51, -0.651]), decimal=3
        )
class TestSoftmax(unittest.TestCase):
    """Tests for SoftmaxPolicy using the call-then-``reinforce`` API.

    NOTE(review): this class has the same name as an earlier ``TestSoftmax``
    in this file and therefore shadows it — confirm which version should
    survive and rename or delete the other.

    Expected actions and log-probs are pinned to fixed RNG seeds; the order
    of ``torch.randn`` and policy calls must stay exactly as written.
    """

    def setUp(self):
        # Deterministic seed so sampled actions match the hard-coded asserts.
        torch.manual_seed(2)
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS))
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.policy = SoftmaxPolicy(self.model, optimizer, ACTIONS)

    def test_run(self):
        """Act twice, reinforce with a huge reward on the second action."""
        state = State(torch.randn(1, STATE_DIM))
        action = self.policy(state)
        self.assertEqual(action.item(), 0)

        state = State(torch.randn(1, STATE_DIM))
        action = self.policy(state)
        self.assertEqual(action.item(), 2)

        # Overwhelming reward on the second step biases the policy toward
        # repeating action 2 after a single update.
        self.policy.reinforce(torch.tensor([-1, 1000000]).float())
        action = self.policy(state)
        self.assertEqual(action.item(), 2)

    def test_multi_action(self):
        """A batch of states produces one action per row."""
        batch = State(torch.randn(3, STATE_DIM))
        actions = self.policy(batch)
        tt.assert_equal(actions, torch.tensor([2, 2, 0]))
        self.policy.reinforce(torch.tensor([[1, 2, 3]]).float())

    def test_multi_batch_reinforce(self):
        """Reinforce may consume buffered steps across batches, but not more
        than were taken."""
        self.policy(State(torch.randn(2, STATE_DIM)))
        self.policy(State(torch.randn(2, STATE_DIM)))
        self.policy(State(torch.randn(2, STATE_DIM)))
        self.policy.reinforce(torch.tensor([1, 2, 3, 4]).float())
        self.policy.reinforce(torch.tensor([1, 2]).float())
        # All six buffered transitions are consumed; a third call must fail.
        with self.assertRaises(Exception):
            self.policy.reinforce(torch.tensor([1, 2]).float())

    def test_list(self):
        """Batched states with a mask tensor still act and reinforce."""
        torch.manual_seed(1)
        batch = State(torch.randn(3, STATE_DIM), torch.tensor([1, 0, 1]))
        actions = self.policy(batch)
        tt.assert_equal(actions, torch.tensor([1, 2, 1]))
        self.policy.reinforce(torch.tensor([[1, 2, 3]]).float())

    def test_action_prob(self):
        """Passing ``action=`` returns log-probs of the given actions."""
        torch.manual_seed(1)
        batch = State(torch.randn(3, STATE_DIM), torch.tensor([1, 0, 1]))
        with torch.no_grad():
            actions = self.policy(batch)
        log_probs = self.policy(batch, action=actions)
        tt.assert_almost_equal(
            log_probs, torch.tensor([-1.59, -1.099, -1.528]), decimal=3
        )

    def test_custom_loss(self):
        """Log-probs of the chosen actions rise with each reinforce step."""

        def neg_mean_loss(log_probs):
            return -log_probs.mean()

        batch = State(torch.randn(3, STATE_DIM), torch.tensor([1, 1, 1]))
        actions = self.policy.eval(batch)

        # notice the values increase with each successive reinforce
        log_probs = self.policy(batch, actions)
        tt.assert_almost_equal(
            log_probs, torch.tensor([-0.84, -0.62, -0.757]), decimal=3
        )
        self.policy.reinforce(neg_mean_loss)

        log_probs = self.policy(batch, actions)
        tt.assert_almost_equal(
            log_probs, torch.tensor([-0.811, -0.561, -0.701]), decimal=3
        )
        self.policy.reinforce(neg_mean_loss)

        log_probs = self.policy(batch, actions)
        tt.assert_almost_equal(
            log_probs, torch.tensor([-0.785, -0.51, -0.651]), decimal=3
        )