Example 1
    def update(self):
        """
        Iterate through the transition queue and make NEC updates
        """
        for t in range(len(self.transition_queue)):
            transition = self.transition_queue[t]
            state = Variable(Tensor(transition.state)).unsqueeze(0)
            action = transition.action
            state_embedding = self.embedding_network(move_to_gpu(state))
            dnd = self.dnd_list[action]

            # N-step Q-value target for this transition
            Q_N = move_to_gpu(self.Q_lookahead(t))
            embedding_index = dnd.get_index(state_embedding)
            if embedding_index is None:
                # Embedding not close to any stored key: queue a new (key, value) pair
                dnd.insert(state_embedding.detach(), Q_N.detach().unsqueeze(0))
            else:
                # Embedding matches an existing key: move its stored value toward Q_N
                Q = self.Q_update(dnd.values[embedding_index], Q_N)
                dnd.update(Q.detach(), embedding_index)
            self.replay_memory.push(transition.state, action,
                                    move_to_gpu(Q_N.detach()))

        for dnd in self.dnd_list:
            dnd.commit_insert()

        for t in range(len(self.transition_queue)):
            if t % self.update_period == 0 or t == len(self.transition_queue) - 1:
                # Train on random mini-batch from self.replay_memory
                batch = self.replay_memory.sample(self.batch_size)
                actual = torch.cat([sample.Q_N for sample in batch])
                predicted = torch.cat([
                    self.dnd_list[sample.action].lookup(
                        self.embedding_network(
                            move_to_gpu(Variable(Tensor(sample.state))).unsqueeze(0)),
                        update_flag=True)
                    for sample in batch])
                # L2 distance between stored N-step targets and current DND predictions
                loss = torch.dist(actual, move_to_gpu(predicted))
                self.optimizer.zero_grad()
                loss.backward()
                self.optimizer.step()
                for dnd in self.dnd_list:
                    dnd.update_params()

        # Clear out transition queue
        self.transition_queue = []
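
The update() method above leans on two helpers that are not shown. The sketch below is a rough guess at what they might look like: move_to_gpu mirrors how it is called above, while the Q_update rule and the q_learning_rate name are assumptions based on the tabular update described in the NEC paper, not code from this project.

import torch

def move_to_gpu(tensor):
    # Move a tensor onto the GPU when one is available, otherwise pass it through
    return tensor.cuda() if torch.cuda.is_available() else tensor

def Q_update(q_initial, q_n, q_learning_rate=0.1):
    # Tabular-style update toward the N-step estimate: Q <- Q + alpha * (Q_N - Q)
    # (q_learning_rate is an assumed hyperparameter name, not taken from the snippets)
    return q_initial + q_learning_rate * (q_n - q_initial)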
Example 2
 def Q_lookahead(self, t, warmup=False):
   """
   Return the N-step Q-value lookahead from time t in the transition queue
   """
   if warmup or len(self.transition_queue) <= t + self.lookahead_horizon:
     lookahead = discount(
         [transition.reward for transition in self.transition_queue[t:]], self.gamma)[0]
     return Variable(Tensor([lookahead]))
   else:
     lookahead = discount(
         [transition.reward for transition in self.transition_queue[t:t+self.lookahead_horizon]], self.gamma)[0]
     state = self.transition_queue[t+self.lookahead_horizon].state
     state_embedding = self.embedding_network(
         move_to_gpu(Variable(Tensor(state)).unsqueeze(0)))
      # Bellman-style N-step target: discounted max Q over all action DNDs at
      # t + lookahead_horizon, plus the discounted rewards over the horizon
      q_max = torch.cat([dnd.lookup(state_embedding)
                         for dnd in self.dnd_list]).max()
      return self.gamma ** self.lookahead_horizon * q_max + lookahead
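
Q_lookahead relies on a discount helper that is not part of the snippet. Below is a minimal sketch, assuming it returns the discounted cumulative sums of a reward sequence (so that index [0] is the return from the start of the slice); the lfilter trick is one common way to write it.

import scipy.signal

def discount(rewards, gamma):
    # Element i of the result is sum_k gamma**k * rewards[i + k];
    # lfilter over the reversed sequence computes this in a single pass
    return scipy.signal.lfilter([1], [1, -gamma], rewards[::-1])[::-1]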
Example 3
  def warmup(self):
    """
    Warmup the DND with values from an episode with a random policy
    """
    state = self.env.reset()
    total_reward = 0
    done = False
    while not done:
      action = random.randint(0, self.env.action_space.n - 1)
      next_state, reward, done, _ = self.env.step(action)
      total_reward += reward
      self.transition_queue.append(Transition(state, action, reward))
      state = next_state

    for t in range(len(self.transition_queue)):
      transition = self.transition_queue[t]
      state = Variable(Tensor(transition.state)).unsqueeze(0)
      action = transition.action
      state_embedding = self.embedding_network(move_to_gpu(state))
      dnd = self.dnd_list[action]

      Q_N = move_to_gpu(self.Q_lookahead(t, True))
      if dnd.keys_to_be_inserted is None and dnd.keys is None:
        # This action's DND is still empty: insert unconditionally
        dnd.insert(state_embedding, Q_N.detach().unsqueeze(0))
      else:
        embedding_index = dnd.get_index(state_embedding)
        if embedding_index is None:
          # Embedding not close to any stored key: queue a new (key, value) pair
          dnd.insert(state_embedding.detach(), Q_N.detach().unsqueeze(0))
        else:
          # Embedding matches an existing key: move its stored value toward Q_N
          Q = self.Q_update(dnd.values[embedding_index], Q_N)
          dnd.update(Q.detach(), embedding_index)
      self.replay_memory.push(transition.state, action, Q_N)
    for dnd in self.dnd_list:
      dnd.commit_insert()
    # Clear out transition queue
    self.transition_queue = []
    return total_reward
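
Both warmup() and update() push Transition records onto self.transition_queue and (state, action, Q_N) triples into the replay memory. Below is a minimal sketch of the record types the snippets appear to assume; the field names are read off the attribute accesses above, and ReplaySample is a hypothetical name for whatever the replay memory actually returns.

from collections import namedtuple

# On-policy transitions collected during an episode
Transition = namedtuple('Transition', ('state', 'action', 'reward'))

# Shape of what replay_memory.push() stores and sample() returns in update()
# (the name ReplaySample is hypothetical)
ReplaySample = namedtuple('ReplaySample', ('state', 'action', 'Q_N'))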
Example 4
 def episode(self):
   """
   Train an NEC agent for a single episode
   Interact with environment on-policy and append all (state, action, reward) transitions to transition queue
   Call update at the end of every episode
   """
    # Decay the exploration rate until it reaches its floor
    if self.epsilon > self.final_epsilon:
      self.epsilon *= self.epsilon_decay
   state = self.env.reset()
   total_reward = 0
   done = False
   while not done:
     state_embedding = self.embedding_network(
         move_to_gpu(Variable(Tensor(state)).unsqueeze(0)))
     action = self.choose_action(state_embedding)
     next_state, reward, done, _ = self.env.step(action)
     self.transition_queue.append(Transition(state, action, reward))
     total_reward += reward
     state = next_state
   self.update()
   return total_reward
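
Putting the pieces together, a driver loop might look like the sketch below. The NEC class name, its constructor arguments, the environment id, and the episode count are assumptions for illustration only; the snippets above do not show how the agent is constructed.

import gym

env = gym.make('CartPole-v0')
agent = NEC(env)          # hypothetical constructor

# Seed the DNDs with one random-policy episode before acting epsilon-greedily
agent.warmup()

for i in range(500):
    episode_reward = agent.episode()   # collect transitions on-policy, then update()
    print('episode {}: reward {}'.format(i, episode_reward))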
Example 5
    def Q_lookahead(self, t, warmup=False):
        """
        Return the N-step Q-value lookahead from time t in the transition queue
        """
        if warmup or len(self.transition_queue) <= t + self.lookahead_horizon:
            lookahead = discount(
                [transition.reward for transition in self.transition_queue[t:]], self.gamma)[0]
            # print('Q_lookahead cond1=TRUE')
            return Variable(Tensor([lookahead]))
        else:
            # Note: there was a problem with this computation; the value
            # returned from this branch came out as a scalar (the unsqueeze(0)
            # below keeps the result 1-dimensional).

            # Discounted cumulative reward from t to t + lookahead_horizon
            lookahead = discount(
                [transition.reward for transition in
                 self.transition_queue[t:t + self.lookahead_horizon]], self.gamma)[0]
            # State of the transition at t + lookahead_horizon
            state = self.transition_queue[t + self.lookahead_horizon].state
            state_embedding = self.embedding_network(move_to_gpu(Variable(Tensor(state)).unsqueeze(0)))
            # print('Q_lookahead cond1=FALSE')
            # Equation (3) in the NEC paper: discounted max Q over the DNDs plus the N-step return
            c = torch.cat([dnd.lookup(state_embedding) for dnd in self.dnd_list])
            # print('c=', c)
            return self.gamma ** self.lookahead_horizon * c.max().unsqueeze(0) + lookahead
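
The dnd.lookup call used throughout is the attention-style read from the NEC paper: a kernel-weighted average of the stored values. Below is a minimal sketch under the assumption that the DND keeps its keys and values as dense tensors and uses the inverse-distance kernel; a real implementation would normally restrict the read to the p approximate nearest neighbours rather than all stored keys.

import torch

def dnd_lookup(keys, values, query, delta=1e-3):
    # keys: (n, d) stored embeddings, values: (n,) stored Q-values, query: (1, d)
    # Inverse-distance kernel from the paper: k(h, h_i) = 1 / (||h - h_i||^2 + delta)
    distances = torch.sum((keys - query) ** 2, dim=1)
    weights = 1.0 / (distances + delta)
    weights = weights / weights.sum()
    # The looked-up Q-value is the kernel-weighted average of the stored values
    return torch.sum(weights * values).unsqueeze(0)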