def assert_states_equal(self, actual, expected):
     """Assert two states agree: raw tensors approximately, masks exactly."""
     raw_pair = (actual.raw, expected.raw)
     mask_pair = (actual.mask, expected.mask)
     tt.assert_almost_equal(*raw_pair)
     tt.assert_equal(*mask_pair)
Example #2
0
 def test_entropy_batch(self):
     """Entropy is computed per row; mirror-image distributions give equal values."""
     probs = torch.tensor([[0.3, 0.7], [0.7, 0.3]])
     expected = torch.tensor([0.61086430205, 0.61086430205])
     actual = metrics.entropy(probs, None)
     tt.assert_almost_equal(actual, expected)
Example #3
0
 def assert_array_equal(self, actual, expected):
     """Pairwise comparison: None entries must align with None; tensor entries
     must match to 3 decimal places. Iteration stops at the shorter sequence."""
     for got, want in zip(actual, expected):
         if want is not None:
             tt.assert_almost_equal(got, want, decimal=3)
         else:
             self.assertIsNone(got)
 def test_with_only_just_almost_equal_one_value_tensors(self):
     """Tensors differing only in the 8th decimal of one element still pass."""
     first = torch.tensor([1, 2, 23.65727169])
     second = torch.tensor([1, 2, 23.65727160])
     tt.assert_almost_equal(first, second)
 def assert_states_equal(self, actual, expected):
     """States match when observations are nearly equal and masks are identical."""
     obs_pair = (actual.observation, expected.observation)
     tt.assert_almost_equal(*obs_pair)
     tt.assert_equal(actual.mask, expected.mask)
Example #6
0
 def test_rainbow_model_cuda(self):
     """Regression-test nature_rainbow's forward pass on GPU.

     Checks the network output against a pinned reference, takes one SGD
     step on the summed output, then checks the post-step output against a
     second pinned reference.

     NOTE(review): calls .cuda() unconditionally — unlike
     test_project_dist_cuda below there is no torch.cuda.is_available()
     guard, so this test errors (rather than skips) on CPU-only machines;
     confirm it is meant to run only on GPU CI.
     NOTE(review): the pinned values presumably rely on a fixed RNG seed
     set elsewhere (setUp/fixture) — confirm.
     """
     env = AtariEnvironment('Breakout')
     model = nature_rainbow(env).cuda()
     env.reset()
     # Stack 4 copies of the initial frame along the channel dimension to
     # form the frame-stacked input the network expects.
     x = torch.cat([env.state.raw] * 4, dim=1).float().cuda()
     out = model(x)
     # Reference output of the freshly-initialized network.
     tt.assert_almost_equal(
         out.cpu(),
         torch.tensor([[
             -1.4765e-02, -4.0353e-02, -2.1705e-02, -2.2314e-02, 3.6881e-02,
             -1.4175e-02, 1.2442e-02, -6.8713e-03, 2.4970e-02, 2.5681e-02,
             -4.5859e-02, -2.3327e-02, 3.6205e-02, 7.1024e-03, -2.7564e-02,
             2.1592e-02, -3.2728e-02, 1.3602e-02, -1.1690e-02, -4.3082e-02,
             -1.2996e-02, 1.7184e-02, 1.3446e-02, -3.3587e-03, -4.6350e-02,
             -1.7646e-02, 2.1954e-02, 8.5546e-03, -2.1359e-02, -2.4206e-02,
             -2.3151e-02, -3.6330e-02, 4.4699e-02, 3.9887e-03, 1.5609e-02,
             -4.3950e-02, 1.0955e-02, -2.4277e-02, 1.4915e-02, 3.2508e-03,
             6.1454e-02, 3.5242e-02, -1.5274e-02, -2.6729e-02, -2.4072e-02,
             1.5696e-02, 2.6622e-02, -3.5404e-02, 5.1701e-02, -5.3047e-02,
             -1.8412e-02, 8.6640e-03, -3.1722e-02, 4.0329e-02, 1.2896e-02,
             -1.4139e-02, -4.9200e-02, -4.6193e-02, -2.9064e-03,
             -2.2078e-02, -4.0084e-02, -8.3519e-03, -2.7589e-02,
             -4.9979e-03, -1.6055e-02, -4.5311e-02, -2.6951e-02, 2.8032e-02,
             -4.0069e-03, 3.2405e-02, -5.3164e-03, -3.0139e-03, 6.6179e-04,
             -4.9243e-02, 3.2515e-02, 9.8307e-03, -3.4257e-03, -3.9522e-02,
             1.2594e-02, -2.7210e-02, 2.3451e-02, 4.2257e-02, 2.2239e-02,
             1.4304e-04, 4.2905e-04, 1.5193e-02, 3.1897e-03, -1.0828e-02,
             -4.8345e-02, 6.8747e-02, -7.1725e-03, -9.7815e-03, -1.6331e-02,
             1.0434e-02, -8.8083e-04, 3.8219e-02, 6.8332e-03, -2.0189e-02,
             2.8141e-02, 1.4913e-02, -2.4925e-02, -2.8922e-02, -7.1546e-03,
             1.9791e-02, 1.1160e-02, 1.0306e-02, -1.3631e-02, 2.7318e-03,
             1.4050e-03, -8.2064e-03, 3.5836e-02, -1.5877e-02, -1.1198e-02,
             1.9514e-02, 3.0832e-03, -6.2730e-02, 6.1493e-03, -1.2340e-02,
             3.9110e-02, -2.6895e-02, -5.1718e-03, 7.5017e-03, 1.2673e-03,
             4.7525e-02, 1.7373e-03, -5.1745e-03, -2.8621e-02, 3.4984e-02,
             -3.2622e-02, 1.0748e-02, 1.2499e-02, -1.8788e-02, -8.6717e-03,
             4.3620e-02, 2.8460e-02, -6.8146e-03, -3.5824e-02, 9.2931e-03,
             3.7893e-03, 2.4187e-02, 1.3393e-02, -5.9393e-03, -9.9837e-03,
             -8.1019e-03, -2.1840e-02, -3.8945e-02, 1.6736e-02, -4.7475e-02,
             4.9770e-02, 3.4695e-02, 1.8961e-02, 2.7416e-02, -1.3578e-02,
             -9.8595e-03, 2.2834e-03, 2.4829e-02, -4.3998e-02, 3.2398e-02,
             -1.4200e-02, 2.4907e-02, -2.2542e-02, -9.2765e-03, 2.0658e-03,
             -4.1246e-03, -1.8095e-02, -1.2732e-02, -3.2090e-03, 1.3127e-02,
             -2.0888e-02, 1.4931e-02, -4.0576e-02, 4.2877e-02, 7.9411e-05,
             -4.4377e-02, 3.2357e-03, 1.6201e-02, 4.0387e-02, -1.9023e-02,
             5.8033e-02, -3.3424e-02, 2.9598e-03, -1.8526e-02, -2.2967e-02,
             4.3449e-02, -1.2564e-02, -9.3756e-03, -2.1745e-02, -2.7089e-02,
             -3.6791e-02, -5.2018e-02, 2.4588e-02, 1.0037e-03, 3.9753e-02,
             4.3534e-02, 2.6446e-02, -1.1808e-02, 2.1426e-02, 7.5522e-03,
             2.2847e-03, -2.7211e-02, 4.1364e-02, -1.1281e-02, 1.6523e-03,
             -1.9913e-03
         ]]),
         decimal=3)
     # One SGD step on the summed output; the updated network must then
     # reproduce the second pinned reference.
     optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
     loss = out.sum()
     loss.backward()
     optimizer.step()
     out = model(x)
     tt.assert_almost_equal(
         out.cpu(),
         torch.tensor([[
             -0.0247, -0.0172, -0.0633, -0.0154, -0.0156, -0.1156, -0.0793,
             -0.0184, -0.0408, 0.0005, -0.0920, -0.0481, -0.0597, -0.0243,
             0.0006, -0.1045, -0.0476, -0.0030, -0.0230, -0.0869, -0.0149,
             -0.0412, -0.0753, -0.0640, -0.1106, -0.0632, -0.0645, -0.0474,
             -0.0124, -0.0698, -0.0275, -0.0415, -0.0916, -0.0957, -0.0851,
             -0.1296, -0.1049, -0.0196, -0.0823, -0.0380, -0.1085, -0.0526,
             -0.0083, -0.1274, -0.0426, -0.0183, -0.0585, -0.0366, -0.1111,
             -0.0074, -0.1238, -0.0324, -0.0166, -0.0719, -0.0285, -0.0427,
             -0.1158, -0.0569, 0.0075, -0.0419, -0.0288, -0.1189, -0.0220,
             -0.0370, 0.0040, 0.0228, -0.0958, -0.0258, -0.0276, -0.0405,
             -0.0958, -0.0201, -0.0639, -0.0543, -0.0705, -0.0940, -0.0700,
             -0.0921, -0.0426, 0.0026, -0.0556, -0.0439, -0.0386, -0.0957,
             -0.0915, -0.0679, -0.1272, -0.0754, -0.0076, -0.1046, -0.0350,
             -0.0887, -0.0350, -0.0270, -0.1188, -0.0449, 0.0020, -0.0406,
             0.0011, -0.0842, -0.0422, -0.1280, -0.0205, 0.0002, -0.0789,
             -0.0185, -0.0510, -0.1180, -0.0550, -0.0159, -0.0702, -0.0029,
             -0.0891, -0.0253, -0.0485, -0.0128, 0.0010, -0.0870, -0.0230,
             -0.0233, -0.0411, -0.0870, -0.0419, -0.0688, -0.0583, -0.0448,
             -0.0864, -0.0926, -0.0758, -0.0540, 0.0058, -0.0843, -0.0365,
             -0.0608, -0.0787, -0.0938, -0.0680, -0.0995, -0.0764, 0.0061,
             -0.0821, -0.0636, -0.0848, -0.0373, -0.0285, -0.1086, -0.0464,
             -0.0228, -0.0464, -0.0279, -0.1053, -0.0224, -0.1268, -0.0006,
             -0.0186, -0.0836, -0.0011, -0.0415, -0.1222, -0.0668, -0.0015,
             -0.0535, -0.0071, -0.1202, -0.0257, -0.0503, 0.0004, 0.0099,
             -0.1113, -0.0182, -0.0080, -0.0216, -0.0661, -0.0115, -0.0468,
             -0.0716, -0.0404, -0.0950, -0.0681, -0.0933, -0.0699, -0.0154,
             -0.0853, -0.0414, -0.0403, -0.0700, -0.0685, -0.0975, -0.0934,
             -0.1016, -0.0121, -0.1084, -0.0391, -0.1006, -0.0441, -0.0024,
             -0.1232, -0.0159, 0.0012, -0.0480, -0.0013, -0.0789, -0.0309,
             -0.1101
         ]]),
         decimal=3)
 def test_atoms(self):
     """The distribution's support atoms span [-2, 2] in unit increments."""
     expected_atoms = torch.tensor([-2, -1, 0, 1, 2])
     tt.assert_almost_equal(self.q.atoms, expected_atoms)
Example #8
0
 def test_tanh_action_bound(self):
     """TanhActionBound saturates extreme inputs to the Box limits and maps
     zero to the midpoint of each action range."""
     space = gym.spaces.Box(np.array([-1.0, 10.0]), np.array([1, 20]))
     model = nn.TanhActionBound(space)
     inputs = torch.tensor([[100.0, 100], [-100, -100], [-100, 100], [0, 0]])
     expected = torch.tensor([[1.0, 20], [-1, 10], [-1, 20], [0.0, 15]])
     tt.assert_almost_equal(model(inputs), expected)
 def test_project_dist_cuda(self):
     """Regression test for QDist.project on GPU.

     Projects a near-uniform 3x51 distribution from a [-9.703, 9.703]
     support onto the QDist's own [-10, 10] support and compares against
     pinned expected values to 3 decimals. Skips silently (no assertion)
     when CUDA is unavailable.
     """
     if torch.cuda.is_available():
         # This gave problems in the past between different cuda version,
         # so a test was added.
         # 51 atoms over [-10, 10]; ACTIONS is defined at module level.
         q = QDist(self.model.cuda(), self.optimizer, ACTIONS, 51, -10.,
                   10.)
         # Batch of 3 probability vectors (each row sums to ~1).
         dist = torch.tensor([[
             0.0190, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192,
             0.0201, 0.0203, 0.0189, 0.0190, 0.0199, 0.0193, 0.0192, 0.0199,
             0.0198, 0.0197, 0.0193, 0.0198, 0.0192, 0.0191, 0.0200, 0.0202,
             0.0191, 0.0202, 0.0198, 0.0200, 0.0198, 0.0193, 0.0192, 0.0202,
             0.0192, 0.0194, 0.0199, 0.0197, 0.0197, 0.0201, 0.0199, 0.0190,
             0.0192, 0.0195, 0.0202, 0.0194, 0.0203, 0.0201, 0.0190, 0.0192,
             0.0201, 0.0201, 0.0192
         ],
                              [
                                  0.0191, 0.0197, 0.0200, 0.0190, 0.0195,
                                  0.0198, 0.0194, 0.0192, 0.0201, 0.0203,
                                  0.0190, 0.0190, 0.0199, 0.0193, 0.0192,
                                  0.0199, 0.0198, 0.0197, 0.0193, 0.0198,
                                  0.0192, 0.0191, 0.0200, 0.0202, 0.0191,
                                  0.0202, 0.0198, 0.0200, 0.0198, 0.0193,
                                  0.0192, 0.0202, 0.0192, 0.0194, 0.0199,
                                  0.0197, 0.0197, 0.0200, 0.0199, 0.0190,
                                  0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
                                  0.0201, 0.0190, 0.0192, 0.0201, 0.0200,
                                  0.0192
                              ],
                              [
                                  0.0191, 0.0197, 0.0200, 0.0190, 0.0195,
                                  0.0198, 0.0194, 0.0192, 0.0200, 0.0203,
                                  0.0190, 0.0191, 0.0199, 0.0193, 0.0192,
                                  0.0199, 0.0198, 0.0197, 0.0193, 0.0198,
                                  0.0192, 0.0191, 0.0199, 0.0202, 0.0192,
                                  0.0202, 0.0198, 0.0200, 0.0198, 0.0193,
                                  0.0192, 0.0202, 0.0192, 0.0194, 0.0199,
                                  0.0197, 0.0197, 0.0200, 0.0199, 0.0190,
                                  0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
                                  0.0201, 0.0190, 0.0192, 0.0201, 0.0200,
                                  0.0192
                              ]]).cuda()
         # Source support: 51 evenly spaced atoms on [-9.703, 9.703],
         # identical for all three batch rows.
         support = torch.tensor(
             [[
                 -9.7030, -9.3149, -8.9268, -8.5386, -8.1505, -7.7624,
                 -7.3743, -6.9862, -6.5980, -6.2099, -5.8218, -5.4337,
                 -5.0456, -4.6574, -4.2693, -3.8812, -3.4931, -3.1050,
                 -2.7168, -2.3287, -1.9406, -1.5525, -1.1644, -0.7762,
                 -0.3881, 0.0000, 0.3881, 0.7762, 1.1644, 1.5525, 1.9406,
                 2.3287, 2.7168, 3.1050, 3.4931, 3.8812, 4.2693, 4.6574,
                 5.0456, 5.4337, 5.8218, 6.2099, 6.5980, 6.9862, 7.3743,
                 7.7624, 8.1505, 8.5386, 8.9268, 9.3149, 9.7030
             ],
              [
                  -9.7030, -9.3149, -8.9268, -8.5386, -8.1505, -7.7624,
                  -7.3743, -6.9862, -6.5980, -6.2099, -5.8218, -5.4337,
                  -5.0456, -4.6574, -4.2693, -3.8812, -3.4931, -3.1050,
                  -2.7168, -2.3287, -1.9406, -1.5525, -1.1644, -0.7762,
                  -0.3881, 0.0000, 0.3881, 0.7762, 1.1644, 1.5525, 1.9406,
                  2.3287, 2.7168, 3.1050, 3.4931, 3.8812, 4.2693, 4.6574,
                  5.0456, 5.4337, 5.8218, 6.2099, 6.5980, 6.9862, 7.3743,
                  7.7624, 8.1505, 8.5386, 8.9268, 9.3149, 9.7030
              ],
              [
                  -9.7030, -9.3149, -8.9268, -8.5386, -8.1505, -7.7624,
                  -7.3743, -6.9862, -6.5980, -6.2099, -5.8218, -5.4337,
                  -5.0456, -4.6574, -4.2693, -3.8812, -3.4931, -3.1050,
                  -2.7168, -2.3287, -1.9406, -1.5525, -1.1644, -0.7762,
                  -0.3881, 0.0000, 0.3881, 0.7762, 1.1644, 1.5525, 1.9406,
                  2.3287, 2.7168, 3.1050, 3.4931, 3.8812, 4.2693, 4.6574,
                  5.0456, 5.4337, 5.8218, 6.2099, 6.5980, 6.9862, 7.3743,
                  7.7624, 8.1505, 8.5386, 8.9268, 9.3149, 9.7030
              ]]).cuda()
         # Pinned projection result (kept on CPU; comparison happens on CPU).
         expected = torch.tensor(
             [[
                 0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202,
                 0.0199, 0.0202, 0.0208, 0.0201, 0.0195, 0.0201, 0.0201,
                 0.0198, 0.0203, 0.0204, 0.0203, 0.0200, 0.0203, 0.0199,
                 0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204, 0.0206,
                 0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204,
                 0.0203, 0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204,
                 0.0204, 0.0205, 0.0208, 0.0200, 0.0197, 0.0204, 0.0207,
                 0.0200, 0.0049
             ],
              [
                  0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202,
                  0.0199, 0.0202, 0.0208, 0.0202, 0.0196, 0.0201, 0.0201,
                  0.0198, 0.0203, 0.0204, 0.0203, 0.0200, 0.0203, 0.0199,
                  0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204, 0.0206,
                  0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204,
                  0.0203, 0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204,
                  0.0204, 0.0205, 0.0208, 0.0200, 0.0197, 0.0204, 0.0206,
                  0.0200, 0.0049
              ],
              [
                  0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202,
                  0.0199, 0.0202, 0.0208, 0.0202, 0.0196, 0.0202, 0.0201,
                  0.0198, 0.0203, 0.0204, 0.0203, 0.0200, 0.0203, 0.0199,
                  0.0197, 0.0204, 0.0208, 0.0198, 0.0214, 0.0204, 0.0206,
                  0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204,
                  0.0203, 0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204,
                  0.0204, 0.0205, 0.0208, 0.0200, 0.0197, 0.0204, 0.0206,
                  0.0200, 0.0049
              ]])
         tt.assert_almost_equal(q.project(dist, support).cpu(),
                                expected.cpu(),
                                decimal=3)