Example #1
# PrioritizedMemory and Transition are project-local; their import is not
# shown in this excerpt (a minimal sketch of both follows the test).
def test_memory():
    """Two transitions added to a capacity-10 sum tree land in leaves 9
    and 10, the first two leaf slots."""
    memory = PrioritizedMemory(10)
    memory.add(15, 1, 2, 3, 4, 5)
    memory.add(10, 4, 5, 6, 5, 2)

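    # sample(n) returns n (tree_index, transition) pairs; zip(*...) splits
    # them into a tuple of indexes and a tuple of transitions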
    indexes, transitions = zip(*memory.sample(2))

    assert indexes == (9, 10)
    assert transitions == (Transition(state=1,
                                      action=2,
                                      reward=3,
                                      next_state=4,
                                      terminal=5),
                           Transition(state=4,
                                      action=5,
                                      reward=6,
                                      next_state=5,
                                      terminal=2))
    """ Example of batch creation """
    assert Transition(*zip(*transitions)) == Transition(state=(1, 4),
                                                        action=(2, 5),
                                                        reward=(3, 6),
                                                        next_state=(4, 5),
                                                        terminal=(5, 2))
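
The test assumes sum-tree storage: with capacity 10 the tree has 19 nodes, the leaves sit at indexes 9-18, and the first two add() calls fill leaves 9 and 10. Below is a minimal sketch of Transition and a proportional PrioritizedMemory consistent with that layout. Everything here, including the stratified sampling, is an assumption reconstructed from the test rather than the project's actual implementation, and it only makes the (9, 10) assertion likely, not guaranteed.

import random
from collections import namedtuple

Transition = namedtuple(
    "Transition", ("state", "action", "reward", "next_state", "terminal"))


class PrioritizedMemory:
    """Sum tree: capacity - 1 internal nodes followed by capacity leaves."""

    def __init__(self, capacity):
        self.capacity = capacity
        self.tree = [0.0] * (2 * capacity - 1)  # node priorities
        self.data = [None] * capacity           # transitions, one per leaf
        self.cursor = 0                         # next leaf slot to overwrite
        self.size = 0
        self.n_step = 1                         # this sketch stores 1-step returns

    def __len__(self):
        return self.size

    def add(self, priority, *fields):
        self.data[self.cursor] = Transition(*fields)
        self._set(self.cursor + self.capacity - 1, priority)
        self.cursor = (self.cursor + 1) % self.capacity
        self.size = min(self.size + 1, self.capacity)

    def _set(self, index, priority):
        delta = priority - self.tree[index]
        while True:                             # propagate the change to the root
            self.tree[index] += delta
            if index == 0:
                return
            index = (index - 1) // 2

    def sample(self, batch_size):
        """Stratified proportional sampling: one draw per equal slice of
        the total priority mass; returns (tree_index, Transition) pairs."""
        slice_width = self.tree[0] / batch_size
        pairs = []
        for i in range(batch_size):
            target = random.uniform(i * slice_width, (i + 1) * slice_width)
            index = 0
            while index < self.capacity - 1:    # descend to a leaf
                left = 2 * index + 1
                if target <= self.tree[left]:
                    index = left
                else:
                    target -= self.tree[left]
                    index = left + 1
            pairs.append((index, self.data[index - self.capacity + 1]))
        return pairs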
Example #2
import copy

import numpy
import torch
import torch.nn.functional as F
from torch import optim

# PrioritizedMemory, Transition, Signal, DQNParameters and
# generator_true_every are project-local helpers not shown in this excerpt.


class DoubleDQN:
    """
    From Deep Reinforcement Learning with Double Q-learning
    at https://arxiv.org/abs/1509.06461
    """
    def __init__(self, DQN, parameters=None):
        """
        DQN: network mapping a batch of states to per-action values; its
             forward is also expected to accept an action tensor and return
             the values gathered at those actions (see _replay).
        parameters: a DQNParameters instance; a fresh default is built when
             omitted (avoids a shared mutable default argument).
        """
        parameters = parameters if parameters is not None else DQNParameters()

        self.on_loss_computed = Signal()
        self.device = torch.device(
            "cuda" if torch.cuda.is_available() else "cpu")

        self.DQN = DQN.to(self.device).train()

        self.frozen_DQN = copy.deepcopy(self.DQN).eval()
        for param in self.frozen_DQN.parameters():
            param.requires_grad = False
        self._update_frozen()

        self.memory = PrioritizedMemory(parameters.capacity)

        self.optimizer = optim.RMSprop(self.DQN.parameters(), lr=parameters.lr)
        self.parameters = parameters

        self.it_s_replay_time = generator_true_every(1)  # learn on every step
        self.it_s_update_frozen_time = generator_true_every(
            self.parameters.frozen_steps)  # periodically sync the frozen network

        self.it_s_action_debug_time = generator_true_every(1000)  # debug cadence

    def _update_frozen(self):
        """
        Let it go: copy the online DQN's current weights into the frozen
        target network.
        """
        self.frozen_DQN.load_state_dict(self.DQN.state_dict())

    def select_action(self, state):
        """
        Return the greedy action once the memory holds more than
        parameters.waiting_time transitions; before that, act uniformly at
        random (warm-up exploration).
        """
        with torch.no_grad():
            values = self.DQN(torch.FloatTensor([state]).to(
                self.device)).cpu().numpy()[0]
            if len(self.memory) > self.parameters.waiting_time:
                selected_action = numpy.argmax(values)
                if next(self.it_s_action_debug_time):
                    print(selected_action, values)
            else:
                selected_action = numpy.random.randint(len(values))

            return selected_action

    def observe(self, state, action, reward, next_state, is_terminal):
        """
        Observe an experience tuple (state, action, reward, next_state, is_terminal)
        """

        if self.parameters.clipping is not None:  # Clip the reward
            reward = numpy.clip(reward, -self.parameters.clipping,
                                self.parameters.clipping)

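        # New transitions get a fixed high priority (10) so that each one is
        # likely sampled before a real TD error has been computed for it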
        self.memory.add(10, state, action, reward, next_state, is_terminal)

        if next(self.it_s_update_frozen_time):
            self._update_frozen()

        if next(self.it_s_replay_time) and len(
                self.memory) > self.parameters.waiting_time:
            self._replay()

    def train(self):
        self.DQN.train()

    def eval(self):
        self.DQN.eval()

    def save(self):
        # nn.Module has no save_state_dict; serialize the state dict instead
        torch.save(self.DQN.state_dict(), "model.torch")

    def _replay(self):
        """
        Sample a prioritized batch and take one gradient step towards the
        Double DQN target:
        y = r + gamma**n_step * Q_frozen(s', argmax_a Q_online(s', a)),
        zeroed at terminal transitions.
        """
        indexes, transitions = zip(
            *self.memory.sample(self.parameters.batch_size))
        # Transpose the batch (see https://stackoverflow.com/a/19343/3343043 for
        # detailed explanation). This converts batch-array of Transitions
        # to Transition of batch-arrays.
        batch = Transition(*zip(*transitions))

        state_values = self.DQN(
            torch.FloatTensor(batch.state).to(self.device),
            torch.LongTensor(batch.action).to(self.device).unsqueeze(1))

        with torch.no_grad():
            next_states = torch.FloatTensor(batch.next_state).to(self.device)
            # Double Q-learning: the online network selects the next action,
            # the frozen target network evaluates it
            best_actions = self.DQN(next_states).argmax(1, keepdim=True)
            next_values = self.frozen_DQN(next_states).gather(1, best_actions)
            not_terminal = 1 - torch.FloatTensor(batch.terminal).to(
                self.device).unsqueeze(1)
            expected_state_values = (
                torch.FloatTensor(batch.reward).to(self.device).unsqueeze(1)
                + self.parameters.gamma ** self.memory.n_step
                * next_values * not_terminal)

        loss = F.mse_loss(state_values, expected_state_values)  # MSE Loss
        self.on_loss_computed.emit(
            loss.detach().cpu().numpy())  # report the computed loss to listeners
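        # A full prioritized-replay step would also refresh the sampled
        # leaves' priorities (via `indexes`) with the new TD errors; this
        # excerpt exposes no such memory API, so none is called here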

        self.optimizer.zero_grad()
        loss.backward()
        for param in self.DQN.parameters():
            # .grad always exists as an attribute but stays None until
            # backward has touched the parameter
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
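
A hypothetical driver loop for the class above, assuming a Gym-style environment env (reset() returns a state; step() returns (next_state, reward, done, info)) and a torch.nn.Module net with the forward signature _replay relies on; env, net and the step count are illustrative, not from the excerpt:

agent = DoubleDQN(net)  # falls back to a default DQNParameters()

state = env.reset()
for _ in range(50_000):
    action = agent.select_action(state)
    next_state, reward, done, _ = env.step(action)
    agent.observe(state, action, reward, next_state, done)  # also triggers learning
    state = env.reset() if done else next_state

agent.eval()
agent.save()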