예제 #1
0
    def agent(self, writer=DummyWriter(), train_steps=float('inf')):
        """Assemble a C51 training agent from this preset's hyperparameters.

        Args:
            writer: Logger passed to the Q-distribution, the epsilon
                schedule and the C51 agent itself.
            train_steps: Part of the signature but not read by this method.

        Returns:
            A ``C51`` agent backed by a new Adam optimizer, ``QDist`` and
            ``ExperienceReplayBuffer``.
        """
        hp = self.hyperparameters

        adam = Adam(self.model.parameters(), lr=hp['lr'])
        dist_q = QDist(
            self.model,
            adam,
            self.n_actions,
            hp['atoms'],
            v_min=hp['v_min'],
            v_max=hp['v_max'],
            target=FixedTarget(hp['target_update_frequency']),
            writer=writer,
        )
        memory = ExperienceReplayBuffer(
            hp['replay_buffer_size'],
            device=self.device,
        )

        # The schedule's last positional argument is the final exploration
        # step minus the replay warm-up length.
        epsilon = LinearScheduler(
            hp['initial_exploration'],
            hp['final_exploration'],
            0,
            hp["final_exploration_step"] - hp["replay_start_size"],
            name="epsilon",
            writer=writer,
        )

        return C51(
            dist_q,
            memory,
            exploration=epsilon,
            discount_factor=hp["discount_factor"],
            minibatch_size=hp["minibatch_size"],
            replay_start_size=hp["replay_start_size"],
            update_frequency=hp["update_frequency"],
            writer=writer,
        )
예제 #2
0
 def _c51(env, writer=DummyWriter()):
     """Build a C51 agent for ``env``.

     All settings (``atoms``, ``lr``, ``device``, ``v_min``, ...) are free
     variables resolved from the surrounding scope.

     Args:
         env: Environment; ``env.action_space.n`` gives the action count.
         writer: Logger shared by the Q-distribution, the epsilon schedule
             and the agent.
     """
     network = fc_relu_dist_q(env, atoms=atoms).to(device)
     adam = Adam(network.parameters(), lr=lr)
     dist_q = QDist(
         network,
         adam,
         env.action_space.n,
         atoms,
         v_min=v_min,
         v_max=v_max,
         writer=writer,
     )
     memory = ExperienceReplayBuffer(replay_buffer_size, device=device)
     epsilon = LinearScheduler(
         initial_exploration,
         final_exploration,
         replay_start_size,
         final_exploration_frame,
         name="epsilon",
         writer=writer,
     )
     return C51(
         dist_q,
         memory,
         exploration=epsilon,
         discount_factor=discount_factor,
         minibatch_size=minibatch_size,
         replay_start_size=replay_start_size,
         update_frequency=update_frequency,
         writer=writer,
     )
예제 #3
0
 def _rainbow(env, writer=DummyWriter()):
     """Build a Rainbow agent for ``env``.

     Settings (``atoms``, ``sigma``, ``alpha``, ``beta``, ``n_steps``, ...)
     are free variables resolved from the surrounding scope.

     Args:
         env: Environment; ``env.action_space.n`` gives the action count.
         writer: Logger shared by the Q-distribution and the agent.
     """
     network = model_constructor(env, atoms=atoms, sigma=sigma).to(device)
     adam = Adam(network.parameters(), lr=lr)
     dist_q = QDist(
         network,
         adam,
         env.action_space.n,
         atoms,
         v_min=v_min,
         v_max=v_max,
         writer=writer,
     )

     # A prioritized buffer wrapped by an n-step buffer.
     prioritized = PrioritizedReplayBuffer(
         replay_buffer_size,
         alpha=alpha,
         beta=beta,
         device=device,
     )
     memory = NStepReplayBuffer(n_steps, discount_factor, prioritized)

     # The agent receives the discount raised to the n-step horizon and a
     # constant exploration value of zero.
     return Rainbow(
         dist_q,
         memory,
         exploration=0.,
         discount_factor=discount_factor ** n_steps,
         minibatch_size=minibatch_size,
         replay_start_size=replay_start_size,
         update_frequency=update_frequency,
         writer=writer,
     )
예제 #4
0
    def agent(self, writer=DummyWriter(), train_steps=float('inf')):
        """Build a multi-agent wrapper around per-agent Rainbow agents.

        One optimizer, one ``QDist`` and one replay buffer are created here
        and closed over by ``agent_constructor``, so every agent built by
        that helper receives the same ``q_dist`` and ``replay_buffer``
        objects.

        Args:
            writer: Logger passed to ``QDist``. The inner
                ``agent_constructor`` takes its own ``writer`` argument,
                which shadows this one inside the helper.
            train_steps: Used to size the LR schedule and as the end of the
                exploration schedule (minus the replay warm-up).

        Returns:
            ``MultiagentEncoder`` wrapping an ``IndependentMultiagent`` that
            maps each entry of ``env.agents`` to a constructed agent.
        """
        # Value handed to CosineAnnealingLR as its second argument:
        # (train_steps - replay warm-up) / update_frequency.
        n_updates = (train_steps - self.hyperparameters['replay_start_size']) / self.hyperparameters['update_frequency']

        optimizer = Adam(
            self.model.parameters(),
            lr=self.hyperparameters['lr'],
            eps=self.hyperparameters['eps']
        )

        q_dist = QDist(
            self.model,
            optimizer,
            self.n_actions,
            self.hyperparameters['atoms'],
            scheduler=CosineAnnealingLR(optimizer, n_updates),
            v_min=self.hyperparameters['v_min'],
            v_max=self.hyperparameters['v_max'],
            target=FixedTarget(self.hyperparameters['target_update_frequency']),
            writer=writer,
        )

        # Prioritized buffer wrapped by an n-step buffer.
        replay_buffer = NStepReplayBuffer(
            self.hyperparameters['n_steps'],
            self.hyperparameters['discount_factor'],
            PrioritizedReplayBuffer(
                self.hyperparameters['replay_buffer_size'],
                alpha=self.hyperparameters['alpha'],
                beta=self.hyperparameters['beta'],
                device=self.device
            )
        )
        # Builds one Rainbow agent wrapped in DeepmindAtariBody; only the
        # exploration schedule and the Rainbow agent use the per-call writer.
        def agent_constructor(writer):
            return DeepmindAtariBody(
                Rainbow(
                    q_dist,
                    replay_buffer,
                    exploration=LinearScheduler(
                        self.hyperparameters['initial_exploration'],
                        self.hyperparameters['final_exploration'],
                        0,
                        train_steps - self.hyperparameters['replay_start_size'],
                        name="exploration",
                        writer=writer
                    ),
                    # Discount raised to the n-step horizon.
                    discount_factor=self.hyperparameters['discount_factor'] ** self.hyperparameters["n_steps"],
                    minibatch_size=self.hyperparameters['minibatch_size'],
                    replay_start_size=self.hyperparameters['replay_start_size'],
                    update_frequency=self.hyperparameters['update_frequency'],
                    writer=writer,
                ),
                lazy_frames=True,
                episodic_lives=True
            )

        # NOTE(review): `writers`, `env` and `device` are not bound anywhere
        # in this function (the parameter is `writer`, and the device is read
        # as `self.device` above). Unless they exist at module scope this
        # raises NameError - confirm the intended sources.
        return MultiagentEncoder(IndependentMultiagent({
            agent : agent_constructor(writers[agent])
            for agent in env.agents
        }), env.agents, device)
예제 #5
0
 def test_agent(self):
     """Build the evaluation-time C51 agent.

     The Q-distribution wraps a deep copy of ``self.model`` and is given
     ``None`` in place of an optimizer.
     """
     hp = self.hyperparameters
     model_copy = copy.deepcopy(self.model)
     q_dist = QDist(
         model_copy,
         None,
         self.n_actions,
         hp['atoms'],
         v_min=hp['v_min'],
         v_max=hp['v_max'],
     )
     policy = C51TestAgent(q_dist, self.n_actions, hp["test_exploration"])
     return DeepmindAtariBody(policy)
예제 #6
0
 def agent_constructor():
     """Build the evaluation-time Rainbow agent.

     ``self`` is taken from the enclosing scope. The Q-distribution wraps
     ``self.model`` directly and is given ``None`` in place of an optimizer.
     """
     hp = self.hyperparameters
     q_dist = QDist(
         self.model,
         None,
         self.n_actions,
         hp['atoms'],
         v_min=hp['v_min'],
         v_max=hp['v_max'],
     )
     policy = RainbowTestAgent(q_dist, self.n_actions, hp["test_exploration"])
     return DeepmindAtariBody(policy)
예제 #7
0
    def agent(self, writer=DummyWriter(), train_steps=float('inf')):
        """Assemble the C51 training agent wrapped in ``DeepmindAtariBody``.

        Args:
            writer: Logger passed to the Q-distribution, the epsilon
                schedule and the C51 agent.
            train_steps: Used only to derive the value handed to the cosine
                learning-rate schedule.

        Returns:
            ``DeepmindAtariBody`` around a ``C51`` agent, with
            ``lazy_frames`` and ``episodic_lives`` enabled.
        """
        hp = self.hyperparameters

        # (train_steps - replay warm-up) / update_frequency
        n_updates = (train_steps - hp['replay_start_size']) / hp['update_frequency']

        adam = Adam(self.model.parameters(), lr=hp['lr'], eps=hp['eps'])

        dist_q = QDist(
            self.model,
            adam,
            self.n_actions,
            hp['atoms'],
            v_min=hp['v_min'],
            v_max=hp['v_max'],
            target=FixedTarget(hp['target_update_frequency']),
            scheduler=CosineAnnealingLR(adam, n_updates),
            writer=writer,
        )

        memory = ExperienceReplayBuffer(hp['replay_buffer_size'], device=self.device)

        epsilon = LinearScheduler(
            hp['initial_exploration'],
            hp['final_exploration'],
            0,
            hp["final_exploration_step"] - hp["replay_start_size"],
            name="epsilon",
            writer=writer,
        )

        c51 = C51(
            dist_q,
            memory,
            exploration=epsilon,
            discount_factor=hp["discount_factor"],
            minibatch_size=hp["minibatch_size"],
            replay_start_size=hp["replay_start_size"],
            update_frequency=hp["update_frequency"],
            writer=writer,
        )
        return DeepmindAtariBody(c51, lazy_frames=True, episodic_lives=True)
예제 #8
0
    def _c51(env, writer=DummyWriter()):
        """Build a C51 agent for ``env`` wrapped in ``DeepmindAtariBody``.

        Settings (``last_frame``, ``atoms``, ``lr``, ``eps``, ...) are free
        variables resolved from the surrounding scope.

        Args:
            env: Environment; ``env.action_space.n`` gives the action count.
            writer: Logger shared by the Q-distribution, the epsilon
                schedule and the agent.
        """
        # Frames -> timesteps (divide by an action repeat of 4) -> updates.
        frames_per_step = 4
        final_step = last_frame / frames_per_step
        total_updates = (final_step - replay_start_size) / update_frequency

        network = nature_c51(env, atoms=atoms).to(device)
        adam = Adam(network.parameters(), lr=lr, eps=eps)
        dist_q = QDist(
            network,
            adam,
            env.action_space.n,
            atoms,
            v_min=v_min,
            v_max=v_max,
            target=FixedTarget(target_update_frequency),
            scheduler=CosineAnnealingLR(adam, total_updates),
            writer=writer,
        )
        memory = ExperienceReplayBuffer(replay_buffer_size, device=device)
        epsilon = LinearScheduler(
            initial_exploration,
            final_exploration,
            0,
            final_step,
            name="epsilon",
            writer=writer,
        )
        c51 = C51(
            dist_q,
            memory,
            exploration=epsilon,
            discount_factor=discount_factor,
            minibatch_size=minibatch_size,
            replay_start_size=replay_start_size,
            update_frequency=update_frequency,
            writer=writer,
        )
        return DeepmindAtariBody(c51, lazy_frames=True)
예제 #9
0
    def agent(self, writer=DummyWriter(), train_steps=float('inf')):
        """Assemble a Rainbow training agent from this preset's hyperparameters.

        Args:
            writer: Logger passed to the Q-distribution, the exploration
                schedule and the Rainbow agent.
            train_steps: End of the exploration schedule before the replay
                warm-up length is subtracted.

        Returns:
            A ``Rainbow`` agent using an n-step wrapper around a
            prioritized replay buffer.
        """
        hp = self.hyperparameters

        adam = Adam(self.model.parameters(), lr=hp['lr'], eps=hp['eps'])

        dist_q = QDist(
            self.model,
            adam,
            self.n_actions,
            hp['atoms'],
            v_min=hp['v_min'],
            v_max=hp['v_max'],
            target=FixedTarget(hp['target_update_frequency']),
            writer=writer,
        )

        memory = NStepReplayBuffer(
            hp['n_steps'],
            hp['discount_factor'],
            PrioritizedReplayBuffer(
                hp['replay_buffer_size'],
                alpha=hp['alpha'],
                beta=hp['beta'],
                device=self.device,
            ),
        )

        schedule = LinearScheduler(
            hp['initial_exploration'],
            hp['final_exploration'],
            0,
            train_steps - hp['replay_start_size'],
            name="exploration",
            writer=writer,
        )

        # The agent receives the discount raised to the n-step horizon.
        return Rainbow(
            dist_q,
            memory,
            exploration=schedule,
            discount_factor=hp['discount_factor'] ** hp["n_steps"],
            minibatch_size=hp['minibatch_size'],
            replay_start_size=hp['replay_start_size'],
            update_frequency=hp['update_frequency'],
            writer=writer,
        )
    def _rainbow(env, writer=DummyWriter()):
        """Build a Rainbow agent for ``env`` wrapped in ``DeepmindAtariBody``.

        Settings (``last_frame``, ``atoms``, ``sigma``, ``n_steps``, ...) are
        free variables resolved from the surrounding scope.

        Args:
            env: Environment; ``env.action_space.n`` gives the action count.
            writer: Logger shared by the Q-distribution, the exploration
                schedule and the agent.
        """
        # Frames -> timesteps (divide by an action repeat of 4) -> updates.
        frames_per_step = 4
        final_step = last_frame / frames_per_step
        total_updates = (final_step - replay_start_size) / update_frequency

        network = model_constructor(env, atoms=atoms, sigma=sigma).to(device)
        adam = Adam(network.parameters(), lr=lr, eps=eps)
        dist_q = QDist(
            network,
            adam,
            env.action_space.n,
            atoms,
            scheduler=CosineAnnealingLR(adam, total_updates),
            v_min=v_min,
            v_max=v_max,
            target=FixedTarget(target_update_frequency),
            writer=writer,
        )

        # Prioritized buffer wrapped by an n-step buffer.
        prioritized = PrioritizedReplayBuffer(
            replay_buffer_size,
            alpha=alpha,
            beta=beta,
            device=device,
        )
        memory = NStepReplayBuffer(n_steps, discount_factor, prioritized)

        schedule = LinearScheduler(
            initial_exploration,
            final_exploration,
            0,
            final_step,
            name='exploration',
            writer=writer,
        )

        return DeepmindAtariBody(
            Rainbow(
                dist_q,
                memory,
                exploration=schedule,
                discount_factor=discount_factor**n_steps,
                minibatch_size=minibatch_size,
                replay_start_size=replay_start_size,
                update_frequency=update_frequency,
                writer=writer,
            ),
            lazy_frames=True,
            episodic_lives=True,
        )
예제 #11
0
 def _c51(env, writer=DummyWriter()):
     """Build a C51 agent for ``env`` wrapped in ``DeepmindAtariBody``.

     Settings (``atoms``, ``lr``, ``eps``, ``v_min``, ...) are free
     variables resolved from the surrounding scope.

     Args:
         env: Environment; ``env.action_space.n`` gives the action count.
         writer: Logger shared by the Q-distribution, the epsilon schedule
             and the agent.
     """
     # Build the network with the same ``atoms`` value that QDist receives
     # below. A hard-coded 51 here only agreed with QDist when ``atoms``
     # happened to be 51.
     model = nature_c51(env, atoms=atoms).to(device)
     optimizer = Adam(
         model.parameters(),
         lr=lr,
         eps=eps
     )
     q = QDist(
         model,
         optimizer,
         env.action_space.n,
         atoms,
         v_min=v_min,
         v_max=v_max,
         target=FixedTarget(target_update_frequency),
         writer=writer,
     )
     replay_buffer = ExperienceReplayBuffer(
         replay_buffer_size,
         device=device
     )
     return DeepmindAtariBody(
         C51(
             q,
             replay_buffer,
             exploration=LinearScheduler(
                 initial_exploration,
                 final_exploration,
                 replay_start_size,
                 final_exploration_frame,
                 name="epsilon",
                 writer=writer,
             ),
             discount_factor=discount_factor,
             minibatch_size=minibatch_size,
             replay_start_size=replay_start_size,
             update_frequency=update_frequency,
             writer=writer
         )
     )
예제 #12
0
 def test_project_dist_cuda(self):
     """Check ``QDist.project`` against reference values on a CUDA device.

     The test is reported as skipped, rather than silently passing, when
     CUDA is not available.
     """
     if not torch.cuda.is_available():
         self.skipTest("CUDA is not available")

     # This gave problems in the past between different cuda version,
     # so a test was added.
     q = QDist(self.model.cuda(), self.optimizer, ACTIONS, 51, -10., 10.)
     dist = torch.tensor([
         [0.0190, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0201,
          0.0203, 0.0189, 0.0190, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197,
          0.0193, 0.0198, 0.0192, 0.0191, 0.0200, 0.0202, 0.0191, 0.0202, 0.0198,
          0.0200, 0.0198, 0.0193, 0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197,
          0.0197, 0.0201, 0.0199, 0.0190, 0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
          0.0201, 0.0190, 0.0192, 0.0201, 0.0201, 0.0192],
         [0.0191, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0201,
          0.0203, 0.0190, 0.0190, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197,
          0.0193, 0.0198, 0.0192, 0.0191, 0.0200, 0.0202, 0.0191, 0.0202, 0.0198,
          0.0200, 0.0198, 0.0193, 0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197,
          0.0197, 0.0200, 0.0199, 0.0190, 0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
          0.0201, 0.0190, 0.0192, 0.0201, 0.0200, 0.0192],
         [0.0191, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0200,
          0.0203, 0.0190, 0.0191, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197,
          0.0193, 0.0198, 0.0192, 0.0191, 0.0199, 0.0202, 0.0192, 0.0202, 0.0198,
          0.0200, 0.0198, 0.0193, 0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197,
          0.0197, 0.0200, 0.0199, 0.0190, 0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
          0.0201, 0.0190, 0.0192, 0.0201, 0.0200, 0.0192],
     ]).cuda()
     # All three support rows are identical, so the row is written once.
     support_row = [
         -9.7030, -9.3149, -8.9268, -8.5386, -8.1505, -7.7624, -7.3743, -6.9862,
         -6.5980, -6.2099, -5.8218, -5.4337, -5.0456, -4.6574, -4.2693, -3.8812,
         -3.4931, -3.1050, -2.7168, -2.3287, -1.9406, -1.5525, -1.1644, -0.7762,
         -0.3881,  0.0000,  0.3881,  0.7762,  1.1644,  1.5525,  1.9406,  2.3287,
         2.7168,  3.1050,  3.4931,  3.8812,  4.2693,  4.6574,  5.0456,  5.4337,
         5.8218,  6.2099,  6.5980,  6.9862,  7.3743,  7.7624,  8.1505,  8.5386,
         8.9268,  9.3149,  9.7030,
     ]
     support = torch.tensor([support_row] * 3).cuda()
     expected = torch.tensor([
         [0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202,
          0.0208, 0.0201, 0.0195, 0.0201, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203,
          0.0200, 0.0203, 0.0199, 0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204,
          0.0206, 0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203,
          0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204, 0.0204, 0.0205, 0.0208,
          0.0200, 0.0197, 0.0204, 0.0207, 0.0200, 0.0049],
         [0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202,
          0.0208, 0.0202, 0.0196, 0.0201, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203,
          0.0200, 0.0203, 0.0199, 0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204,
          0.0206, 0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203,
          0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204, 0.0204, 0.0205, 0.0208,
          0.0200, 0.0197, 0.0204, 0.0206, 0.0200, 0.0049],
         [0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202,
          0.0208, 0.0202, 0.0196, 0.0202, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203,
          0.0200, 0.0203, 0.0199, 0.0197, 0.0204, 0.0208, 0.0198, 0.0214, 0.0204,
          0.0206, 0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203,
          0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204, 0.0204, 0.0205, 0.0208,
          0.0200, 0.0197, 0.0204, 0.0206, 0.0200, 0.0049],
     ])
     # Compare on the CPU; ``expected`` was never moved to the GPU.
     tt.assert_almost_equal(q.project(dist, support).cpu(), expected.cpu(), decimal=3)
예제 #13
0
 def setUp(self):
     """Seed torch, then build the model, optimizer and ``QDist`` under test."""
     torch.manual_seed(2)
     net = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS * ATOMS))
     sgd = torch.optim.SGD(net.parameters(), lr=0.1)
     self.model = net
     self.optimizer = sgd
     self.q = QDist(net, sgd, ACTIONS, ATOMS, V_MIN, V_MAX)
예제 #14
0
class TestQDist(unittest.TestCase):
    """Unit tests for ``QDist``: atoms, forward pass, masking, update and projection."""

    def setUp(self):
        """Seed torch, then build the model, optimizer and ``QDist`` under test."""
        torch.manual_seed(2)
        self.model = nn.Sequential(nn.Linear(STATE_DIM, ACTIONS * ATOMS))
        self.optimizer = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.q = QDist(self.model, self.optimizer, ACTIONS, ATOMS, V_MIN, V_MAX)

    def test_atoms(self):
        """The atom support should be the five integers from -2 to 2."""
        tt.assert_almost_equal(self.q.atoms, torch.tensor([-2, -1, 0, 1, 2]))

    def test_q_values(self):
        """Calling ``q(states)`` yields per-action distributions that sum to one."""
        states = StateArray(torch.randn((3, STATE_DIM)), (3,))
        probs = self.q(states)
        self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
        tt.assert_almost_equal(
            probs.sum(dim=2),
            torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
            decimal=3,
        )
        tt.assert_almost_equal(
            probs,
            torch.tensor(
                [
                    [
                        [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                        [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
                    ],
                    [
                        [0.1966, 0.1299, 0.1431, 0.3167, 0.2137],
                        [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                    ],
                    [
                        [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                        [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
                    ],
                ]
            ),
            decimal=3,
        )

    def test_single_q_values(self):
        """Calling ``q(states, actions)`` returns one distribution per state."""
        states = StateArray(torch.randn((3, STATE_DIM)), (3,))
        actions = torch.tensor([0, 1, 0])
        probs = self.q(states, actions)
        self.assertEqual(probs.shape, (3, ATOMS))
        tt.assert_almost_equal(
            probs.sum(dim=1), torch.tensor([1.0, 1.0, 1.0]), decimal=3
        )
        tt.assert_almost_equal(
            probs,
            torch.tensor(
                [
                    [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                    [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                    [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                ]
            ),
            decimal=3,
        )

    def test_done(self):
        """The state whose mask entry is 0 gets all mass on the middle atom."""
        states = StateArray(torch.randn((3, STATE_DIM)), (3,), mask=torch.tensor([1, 0, 1]))
        probs = self.q(states)
        self.assertEqual(probs.shape, (3, ACTIONS, ATOMS))
        tt.assert_almost_equal(
            probs.sum(dim=2),
            torch.tensor([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]]),
            decimal=3,
        )
        tt.assert_almost_equal(
            probs,
            torch.tensor(
                [
                    [
                        [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                        [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
                    ],
                    [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
                    [
                        [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                        [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
                    ],
                ]
            ),
            decimal=3,
        )

    def test_reinforce(self):
        """One ``reinforce`` step moves each probability toward its target."""
        states = StateArray(torch.randn((3, STATE_DIM)), (3,))
        actions = torch.tensor([0, 1, 0])
        original_probs = self.q(states, actions)
        tt.assert_almost_equal(
            original_probs,
            torch.tensor(
                [
                    [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                    [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
                    [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                ]
            ),
            decimal=3,
        )

        target_dists = torch.tensor(
            [[0, 0, 1, 0, 0], [0, 0, 0, 0, 1], [0, 1, 0, 0, 0]]
        ).float()

        def _loss(dist, target_dist):
            # Both distributions are clamped at 1e-5 before taking the log.
            log_dist = torch.log(torch.clamp(dist, min=1e-5))
            log_target_dist = torch.log(torch.clamp(target_dist, min=1e-5))
            return (target_dist * (log_target_dist - log_dist)).sum(dim=-1).mean()

        self.q.reinforce(_loss(original_probs, target_dists))

        new_probs = self.q(states, actions)
        # sign(target - 0.5) is +1 where the target is 1 and -1 where it is 0,
        # so the probability change must have that same sign.
        tt.assert_almost_equal(
            torch.sign(new_probs - original_probs), torch.sign(target_dists - 0.5)
        )

    @staticmethod
    def _projection_fixture():
        """Return the ``(dist, support, expected)`` CPU tensors shared by the projection tests."""
        dist = torch.tensor([
            [0.0190, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0201,
             0.0203, 0.0189, 0.0190, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197,
             0.0193, 0.0198, 0.0192, 0.0191, 0.0200, 0.0202, 0.0191, 0.0202, 0.0198,
             0.0200, 0.0198, 0.0193, 0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197,
             0.0197, 0.0201, 0.0199, 0.0190, 0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
             0.0201, 0.0190, 0.0192, 0.0201, 0.0201, 0.0192],
            [0.0191, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0201,
             0.0203, 0.0190, 0.0190, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197,
             0.0193, 0.0198, 0.0192, 0.0191, 0.0200, 0.0202, 0.0191, 0.0202, 0.0198,
             0.0200, 0.0198, 0.0193, 0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197,
             0.0197, 0.0200, 0.0199, 0.0190, 0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
             0.0201, 0.0190, 0.0192, 0.0201, 0.0200, 0.0192],
            [0.0191, 0.0197, 0.0200, 0.0190, 0.0195, 0.0198, 0.0194, 0.0192, 0.0200,
             0.0203, 0.0190, 0.0191, 0.0199, 0.0193, 0.0192, 0.0199, 0.0198, 0.0197,
             0.0193, 0.0198, 0.0192, 0.0191, 0.0199, 0.0202, 0.0192, 0.0202, 0.0198,
             0.0200, 0.0198, 0.0193, 0.0192, 0.0202, 0.0192, 0.0194, 0.0199, 0.0197,
             0.0197, 0.0200, 0.0199, 0.0190, 0.0192, 0.0195, 0.0202, 0.0194, 0.0203,
             0.0201, 0.0190, 0.0192, 0.0201, 0.0200, 0.0192],
        ])
        # All three support rows are identical, so the row is written once.
        support_row = [
            -9.7030, -9.3149, -8.9268, -8.5386, -8.1505, -7.7624, -7.3743, -6.9862,
            -6.5980, -6.2099, -5.8218, -5.4337, -5.0456, -4.6574, -4.2693, -3.8812,
            -3.4931, -3.1050, -2.7168, -2.3287, -1.9406, -1.5525, -1.1644, -0.7762,
            -0.3881,  0.0000,  0.3881,  0.7762,  1.1644,  1.5525,  1.9406,  2.3287,
            2.7168,  3.1050,  3.4931,  3.8812,  4.2693,  4.6574,  5.0456,  5.4337,
            5.8218,  6.2099,  6.5980,  6.9862,  7.3743,  7.7624,  8.1505,  8.5386,
            8.9268,  9.3149,  9.7030,
        ]
        support = torch.tensor([support_row] * 3)
        expected = torch.tensor([
            [0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202,
             0.0208, 0.0201, 0.0195, 0.0201, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203,
             0.0200, 0.0203, 0.0199, 0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204,
             0.0206, 0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203,
             0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204, 0.0204, 0.0205, 0.0208,
             0.0200, 0.0197, 0.0204, 0.0207, 0.0200, 0.0049],
            [0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202,
             0.0208, 0.0202, 0.0196, 0.0201, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203,
             0.0200, 0.0203, 0.0199, 0.0197, 0.0205, 0.0208, 0.0197, 0.0214, 0.0204,
             0.0206, 0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203,
             0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204, 0.0204, 0.0205, 0.0208,
             0.0200, 0.0197, 0.0204, 0.0206, 0.0200, 0.0049],
            [0.0049, 0.0198, 0.0204, 0.0202, 0.0198, 0.0202, 0.0202, 0.0199, 0.0202,
             0.0208, 0.0202, 0.0196, 0.0202, 0.0201, 0.0198, 0.0203, 0.0204, 0.0203,
             0.0200, 0.0203, 0.0199, 0.0197, 0.0204, 0.0208, 0.0198, 0.0214, 0.0204,
             0.0206, 0.0203, 0.0199, 0.0199, 0.0206, 0.0198, 0.0201, 0.0204, 0.0203,
             0.0204, 0.0206, 0.0201, 0.0197, 0.0199, 0.0204, 0.0204, 0.0205, 0.0208,
             0.0200, 0.0197, 0.0204, 0.0206, 0.0200, 0.0049],
        ])
        return dist, support, expected

    def test_project_dist(self):
        """Check ``QDist.project`` against reference values on the CPU."""
        # This gave problems in the past between different cuda version,
        # so a test was added.
        q = QDist(self.model, self.optimizer, ACTIONS, 51, -10., 10.)
        dist, support, expected = self._projection_fixture()
        tt.assert_almost_equal(q.project(dist, support).cpu(), expected.cpu(), decimal=3)

    def test_project_dist_cuda(self):
        """Check ``QDist.project`` on a CUDA device; skipped when CUDA is absent."""
        if not torch.cuda.is_available():
            # Report a skip instead of silently passing without running.
            self.skipTest("CUDA is not available")

        # This gave problems in the past between different cuda version,
        # so a test was added.
        q = QDist(self.model.cuda(), self.optimizer, ACTIONS, 51, -10., 10.)
        dist, support, expected = self._projection_fixture()
        projected = q.project(dist.cuda(), support.cuda())
        tt.assert_almost_equal(projected.cpu(), expected.cpu(), decimal=3)
class TestQDist(unittest.TestCase):
    def setUp(self):
        """Create a reproducibly seeded QDist fixture backed by one linear layer."""
        # Seed first: the layer's initial weights and the random states drawn
        # by the individual tests both depend on the RNG state set here.
        torch.manual_seed(2)
        linear = nn.Linear(STATE_DIM, ACTIONS * ATOMS)
        self.model = nn.Sequential(linear)
        sgd = torch.optim.SGD(self.model.parameters(), lr=0.1)
        self.q = QDist(self.model, sgd, ACTIONS, ATOMS, V_MIN, V_MAX)

    def test_atoms(self):
        """The fixture's ``atoms`` attribute should equal -2, -1, 0, 1, 2."""
        expected_atoms = torch.tensor([-2, -1, 0, 1, 2])
        tt.assert_almost_equal(self.q.atoms, expected_atoms)

    def test_q_values(self):
        """Calling q on a batch of states gives one normalized distribution per action."""
        batch = State(torch.randn((3, STATE_DIM)))
        dists = self.q(batch)

        self.assertEqual(dists.shape, (3, ACTIONS, ATOMS))

        # Summing over the last (atom) axis must give 1 for every
        # (state, action) pair.
        totals = dists.sum(dim=2)
        tt.assert_almost_equal(totals, torch.ones((3, 2)), decimal=3)

        # Reference values recorded for the seed chosen in setUp.
        expected = torch.tensor([
            [
                [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
            ],
            [
                [0.1966, 0.1299, 0.1431, 0.3167, 0.2137],
                [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
            ],
            [
                [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
            ],
        ])
        tt.assert_almost_equal(dists, expected, decimal=3)

    def test_single_q_values(self):
        """Passing actions alongside states returns one distribution per state."""
        batch = State(torch.randn((3, STATE_DIM)))
        chosen = torch.tensor([0, 1, 0])
        dists = self.q(batch, chosen)

        self.assertEqual(dists.shape, (3, ATOMS))

        # Every returned row must sum to 1 over the atom axis.
        totals = dists.sum(dim=1)
        tt.assert_almost_equal(totals, torch.ones(3), decimal=3)

        # Reference values recorded for the seed chosen in setUp.
        expected = torch.tensor([
            [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
            [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
            [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
        ])
        tt.assert_almost_equal(dists, expected, decimal=3)

    def test_done(self):
        """A state whose mask entry is 0 should put all probability on the middle atom."""
        observations = torch.randn((3, STATE_DIM))
        # Only the second state in the batch carries a zero mask.
        batch = State(observations, mask=torch.tensor([1, 0, 1]))
        dists = self.q(batch)

        self.assertEqual(dists.shape, (3, ACTIONS, ATOMS))

        # Masked or not, every (state, action) row must still sum to 1.
        totals = dists.sum(dim=2)
        tt.assert_almost_equal(totals, torch.ones((3, 2)), decimal=3)

        expected = torch.tensor([
            [
                [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
                [0.3903, 0.2471, 0.0360, 0.1733, 0.1533],
            ],
            # Zero-mask state: the same one-hot row is expected for both actions.
            [[0, 0, 1, 0, 0], [0, 0, 1, 0, 0]],
            [
                [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
                [0.0819, 0.1320, 0.1203, 0.0373, 0.6285],
            ],
        ])
        tt.assert_almost_equal(dists, expected, decimal=3)

    def test_reinforce(self):
        """After one reinforce call, mass should rise on target atoms and fall elsewhere."""
        batch = State(torch.randn((3, STATE_DIM)))
        chosen = torch.tensor([0, 1, 0])

        before = self.q(batch, chosen)
        # Confirm the starting point matches the values recorded for the
        # seed chosen in setUp before applying any update.
        expected_before = torch.tensor([
            [0.2065, 0.1045, 0.1542, 0.2834, 0.2513],
            [0.3190, 0.2471, 0.0534, 0.1424, 0.2380],
            [0.1427, 0.2486, 0.0946, 0.4112, 0.1029],
        ])
        tt.assert_almost_equal(before, expected_before, decimal=3)

        # One-hot target distribution for each of the three states.
        target_dists = torch.tensor([
            [0, 0, 1, 0, 0],
            [0, 0, 0, 0, 1],
            [0, 1, 0, 0, 0],
        ]).float()
        self.q.reinforce(target_dists)

        after = self.q(batch, chosen)
        # target - 0.5 is positive exactly at the target atom and negative
        # everywhere else, so its sign gives the expected direction of
        # change for each probability.
        change_direction = torch.sign(after - before)
        expected_direction = torch.sign(target_dists - 0.5)
        tt.assert_almost_equal(change_direction, expected_direction)