class TestQDist(unittest.TestCase):
    """Tests for ``QDist`` driven through the ``StateArray`` API.

    ``setUp`` seeds torch with 2 before building the model, so the hard-coded
    expected probabilities below depend on that seed, on ``STATE_DIM``,
    ``ACTIONS`` and ``ATOMS``, and on the order in which random numbers are
    drawn inside each test.
    """

    def setUp(self):
        # Seed before building the model so its initial weights (and the
        # expected probabilities in the tests) are reproducible.
        torch.manual_seed(2)
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS * ATOMS))
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.q = QDist(self.model, self.optimizer, ACTIONS, ATOMS, V_MIN, V_MAX)

    def test_atoms(self):
        tt.assert_almost_equal(self.q.atoms, torch.tensor([-2, -1, 0, 1, 2]))

    def test_q_values(self):
        states = StateArray(torch.randn((3, STATE_DIM)), (3,))
        probs = self.q(states)
        self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
        # Each (state, action) distribution must sum to one over the atoms.
        tt.assert_almost_equal(
            probs.sum(dim=2),
            torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
            decimal=3,
        )
        tt.assert_almost_equal(
            probs,
            torch.tensor(
                [
                    [
                        [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                        [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
                    ],
                    [
                        [0.1966, 0.1299, 0.1431, 0.3167, 0.2137],
                        [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                    ],
                    [
                        [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                        [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
                    ],
                ]
            ),
            decimal=3,
        )

    def test_single_q_values(self):
        states = StateArray(torch.randn((3, STATE_DIM)), (3,))
        actions = torch.tensor([0, 1, 0])
        probs = self.q(states, actions)
        self.assertEqual(probs.shape, (3, ATOMS))
        tt.assert_almost_equal(
            probs.sum(dim=1), torch.tensor([1.0, 1.0, 1.0]), decimal=3
        )
        tt.assert_almost_equal(
            probs,
            torch.tensor(
                [
                    [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                    [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                    [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                ]
            ),
            decimal=3,
        )

    def test_done(self):
        states = StateArray(
            torch.randn((3, STATE_DIM)), (3,), mask=torch.tensor([1, 0, 1])
        )
        probs = self.q(states)
        self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
        tt.assert_almost_equal(
            probs.sum(dim=2),
            torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
            decimal=3,
        )
        # The masked-out state (mask == 0) is expected to put all of its
        # probability mass on the middle atom, i.e. value 0 (see test_atoms).
        tt.assert_almost_equal(
            probs,
            torch.tensor(
                [
                    [
                        [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                        [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
                    ],
                    [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
                    [
                        [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                        [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
                    ],
                ]
            ),
            decimal=3,
        )

    def test_reinforce(self):
        states = StateArray(torch.randn((3, STATE_DIM)), (3,))
        actions = torch.tensor([0, 1, 0])
        original_probs = self.q(states, actions)
        tt.assert_almost_equal(
            original_probs,
            torch.tensor(
                [
                    [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                    [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                    [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                ]
            ),
            decimal=3,
        )

        target_dists = torch.tensor(
            [[0, 0, 1, 0, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0]]
        ).float()

        def _loss(dist, target_dist):
            # KL(target || dist), averaged over the batch; both sides are
            # clamped so log() never sees an exact zero.
            log_dist = torch.log(torch.clamp(dist, min=1e-5))
            log_target_dist = torch.log(torch.clamp(target_dist, min=1e-5))
            return (target_dist * (log_target_dist - log_dist)).sum(dim=-1).mean()

        self.q.reinforce(_loss(original_probs, target_dists))

        new_probs = self.q(states, actions)
        # One update step should raise the probability of every target atom
        # (target == 1 -> sign +1) and lower every other atom (sign -1).
        tt.assert_almost_equal(
            torch.sign(new_probs - original_probs), torch.sign(target_dists - 0.5)
        )

    def test_project_dist(self):
        # This gave problems in the past between different CUDA versions,
        # so a regression test was added.
        self._assert_project_matches("cpu")

    @unittest.skipUnless(torch.cuda.is_available(), "CUDA is not available")
    def test_project_dist_cuda(self):
        # Same regression check as test_project_dist, run on the GPU. Skipped
        # (and reported as skipped) rather than silently passing without CUDA.
        self._assert_project_matches("cuda")

    def _assert_project_matches(self, device):
        """Check ``QDist.project`` on the shared fixture, computed on ``device``.

        Builds a 51-atom ``QDist`` over [-10, 10] from ``self.model`` moved to
        ``device``, projects the fixture distributions onto it, and compares
        the result (moved back to CPU) with the reference to 3 decimals.
        """
        q = QDist(self.model.to(device), self.optimizer, ACTIONS, 51, -10.0, 10.0)
        dist, support, expected = self._project_fixture()
        projected = q.project(dist.to(device), support.to(device))
        tt.assert_almost_equal(projected.cpu(), expected.cpu(), decimal=3)

    @staticmethod
    def _project_fixture():
        """Return fresh CPU tensors ``(dist, support, expected)`` for projection.

        ``dist`` and ``expected`` are 3 x 51; ``support`` is the same 51-value
        row repeated for each of the 3 distributions.
        """
        dist = torch.tensor(
            [
                [
                    0.0190, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0201, 0.0203,
                    0.0189, 0.0190, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197, 0.0193, 0.0198,
                    0.0192, 0.0191, 0.0200, 0.0202, 0.0191, 0.0202, 0.0198, 0.0200, 0.0198, 0.0193,
                    0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197, 0.0197, 0.0201, 0.0199, 0.0190,
                    0.0192, 0.0195, 0.0202, 0.0194, 0.0203, 0.0201, 0.0190, 0.0192, 0.0201, 0.0201,
                    0.0192,
                ],
                [
                    0.0191, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0201, 0.0203,
                    0.0190, 0.0190, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197, 0.0193, 0.0198,
                    0.0192, 0.0191, 0.0200, 0.0202, 0.0191, 0.0202, 0.0198, 0.0200, 0.0198, 0.0193,
                    0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197, 0.0197, 0.0200, 0.0199, 0.0190,
                    0.0192, 0.0195, 0.0202, 0.0194, 0.0203, 0.0201, 0.0190, 0.0192, 0.0201, 0.0200,
                    0.0192,
                ],
                [
                    0.0191, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0200, 0.0203,
                    0.0190, 0.0191, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197, 0.0193, 0.0198,
                    0.0192, 0.0191, 0.0199, 0.0202, 0.0192, 0.0202, 0.0198, 0.0200, 0.0198, 0.0193,
                    0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197, 0.0197, 0.0200, 0.0199, 0.0190,
                    0.0192, 0.0195, 0.0202, 0.0194, 0.0203, 0.0201, 0.0190, 0.0192, 0.0201, 0.0200,
                    0.0192,
                ],
            ]
        )
        support_row = [
            -9.7030, -9.3149, -8.9268, -8.5386, -8.1505, -7.7624, -7.3743, -6.9862, -6.5980, -6.2099,
            -5.8218, -5.4337, -5.0456, -4.6574, -4.2693, -3.8812, -3.4931, -3.1050, -2.7168, -2.3287,
            -1.9406, -1.5525, -1.1644, -0.7762, -0.3881, 0.0000, 0.3881, 0.7762, 1.1644, 1.5525,
            1.9406, 2.3287, 2.7168, 3.1050, 3.4931, 3.8812, 4.2693, 4.6574, 5.0456, 5.4337,
            5.8218, 6.2099, 6.5980, 6.9862, 7.3743, 7.7624, 8.1505, 8.5386, 8.9268, 9.3149,
            9.7030,
        ]
        # All three distributions share the same support.
        support = torch.tensor([support_row] * 3)
        expected = torch.tensor(
            [
                [
                    0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202, 0.0208,
                    0.0201, 0.0195, 0.0201, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203, 0.0200, 0.0203,
                    0.0199, 0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204, 0.0206, 0.0203, 0.0199,
                    0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203, 0.0204, 0.0206, 0.0201, 0.0197,
                    0.0199, 0.0204, 0.0204, 0.0205, 0.0208, 0.0200, 0.0197, 0.0204, 0.0207, 0.0200,
                    0.0049,
                ],
                [
                    0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202, 0.0208,
                    0.0202, 0.0196, 0.0201, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203, 0.0200, 0.0203,
                    0.0199, 0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204, 0.0206, 0.0203, 0.0199,
                    0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203, 0.0204, 0.0206, 0.0201, 0.0197,
                    0.0199, 0.0204, 0.0204, 0.0205, 0.0208, 0.0200, 0.0197, 0.0204, 0.0206, 0.0200,
                    0.0049,
                ],
                [
                    0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202, 0.0208,
                    0.0202, 0.0196, 0.0202, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203, 0.0200, 0.0203,
                    0.0199, 0.0197, 0.0204, 0.0208, 0.0198, 0.0214, 0.0204, 0.0206, 0.0203, 0.0199,
                    0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203, 0.0204, 0.0206, 0.0201, 0.0197,
                    0.0199, 0.0204, 0.0204, 0.0205, 0.0208, 0.0200, 0.0197, 0.0204, 0.0206, 0.0200,
                    0.0049,
                ],
            ]
        )
        return dist, support, expected
class TestQDistWithState(unittest.TestCase):
    """``QDist`` tests written against the ``State`` / ``reinforce(targets)`` API.

    This class used to reuse the name ``TestQDist``. Because it is defined
    after the ``StateArray``-based ``TestQDist`` class in this module, the
    duplicate name rebound the module attribute and unittest discovery only
    collected this class, silently dropping every test of the other one.
    It has been given a distinct name so that both classes are collected.

    NOTE(review): this class calls ``self.q.reinforce(target_dists)``, whereas
    the ``StateArray``-based ``TestQDist`` calls ``self.q.reinforce(loss)``
    with a scalar loss. Presumably only one of these matches the current
    ``QDist`` API; confirm and delete the stale class.
    """

    def setUp(self):
        # Seed before building the model so its initial weights (and the
        # expected probabilities in the tests) are reproducible.
        torch.manual_seed(2)
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS * ATOMS))
        optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.q = QDist(self.model, optimizer, ACTIONS, ATOMS, V_MIN, V_MAX)

    def test_atoms(self):
        tt.assert_almost_equal(self.q.atoms, torch.tensor([-2, -1, 0, 1, 2]))

    def test_q_values(self):
        states = State(torch.randn((3, STATE_DIM)))
        probs = self.q(states)
        self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
        # Each (state, action) distribution must sum to one over the atoms.
        tt.assert_almost_equal(
            probs.sum(dim=2),
            torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
            decimal=3,
        )
        tt.assert_almost_equal(
            probs,
            torch.tensor([
                [
                    [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                    [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
                ],
                [
                    [0.1966, 0.1299, 0.1431, 0.3167, 0.2137],
                    [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                ],
                [
                    [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                    [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
                ],
            ]),
            decimal=3,
        )

    def test_single_q_values(self):
        states = State(torch.randn((3, STATE_DIM)))
        actions = torch.tensor([0, 1, 0])
        probs = self.q(states, actions)
        self.assertEqual(probs.shape, (3, ATOMS))
        tt.assert_almost_equal(probs.sum(dim=1), torch.tensor([1.0, 1.0, 1.0]), decimal=3)
        tt.assert_almost_equal(
            probs,
            torch.tensor([
                [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
            ]),
            decimal=3,
        )

    def test_done(self):
        states = State(torch.randn((3, STATE_DIM)), mask=torch.tensor([1, 0, 1]))
        probs = self.q(states)
        self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
        tt.assert_almost_equal(
            probs.sum(dim=2),
            torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
            decimal=3,
        )
        # The masked-out state (mask == 0) is expected to put all of its
        # probability mass on the middle atom, i.e. value 0 (see test_atoms).
        tt.assert_almost_equal(
            probs,
            torch.tensor([
                [
                    [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                    [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
                ],
                [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
                [
                    [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                    [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
                ],
            ]),
            decimal=3,
        )

    def test_reinforce(self):
        states = State(torch.randn((3, STATE_DIM)))
        actions = torch.tensor([0, 1, 0])
        original_probs = self.q(states, actions)
        tt.assert_almost_equal(
            original_probs,
            torch.tensor([
                [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
            ]),
            decimal=3,
        )

        target_dists = torch.tensor([[0, 0, 1, 0, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0]]).float()
        self.q.reinforce(target_dists)

        new_probs = self.q(states, actions)
        # After one update, target atoms (target == 1 -> sign +1) should gain
        # probability and all other atoms (sign -1) should lose it.
        tt.assert_almost_equal(torch.sign(new_probs - original_probs), torch.sign(target_dists - 0.5))