def update(self):
    """
    Iterate through the transition queue and make NEC updates
    """
    # First pass: write the N-step lookahead value of every queued transition
    # into the DND of its action, and push it to the replay memory
    for t in range(len(self.transition_queue)):
        transition = self.transition_queue[t]
        state = Variable(Tensor(transition.state)).unsqueeze(0)
        action = transition.action
        state_embedding = self.embedding_network(move_to_gpu(state))
        dnd = self.dnd_list[action]

        Q_N = move_to_gpu(self.Q_lookahead(t))
        embedding_index = dnd.get_index(state_embedding)
        if embedding_index is None:
            # No sufficiently close key in the DND: insert a new entry
            dnd.insert(state_embedding.detach(), Q_N.detach().unsqueeze(0))
        else:
            # Key already present: blend the stored value with the new estimate
            Q = self.Q_update(dnd.values[embedding_index], Q_N)
            dnd.update(Q.detach(), embedding_index)
        self.replay_memory.push(transition.state, action, move_to_gpu(Q_N.detach()))
    for dnd in self.dnd_list:
        dnd.commit_insert()

    # Second pass: periodically train on random mini-batches from the replay memory
    for t in range(len(self.transition_queue)):
        if t % self.update_period == 0 or t == len(self.transition_queue) - 1:
            batch = self.replay_memory.sample(self.batch_size)
            actual = torch.cat([sample.Q_N for sample in batch])
            predicted = torch.cat([
                self.dnd_list[sample.action].lookup(
                    self.embedding_network(
                        move_to_gpu(Variable(Tensor(sample.state))).unsqueeze(0)),
                    update_flag=True)
                for sample in batch])
            loss = torch.dist(actual, move_to_gpu(predicted))
            self.optimizer.zero_grad()
            loss.backward()
            self.optimizer.step()
            for dnd in self.dnd_list:
                dnd.update_params()

    # Clear out the transition queue
    self.transition_queue = []
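# Q_update() is called above but not shown in this section. Below is a minimal
# sketch of what it is assumed to do, following the tabular NEC value update
# Q <- Q + alpha * (Q_N - Q); the learning-rate attribute name `self.q_lr` is a
# placeholder, not necessarily the name used in the actual implementation.
def Q_update(self, Q, Q_N):
    """Blend an existing DND value with a new N-step estimate (sketch)."""
    return Q + self.q_lr * (Q_N - Q)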
def Q_lookahead(self, t, warmup=False):
    """
    Return the N-step Q-value lookahead from time t in the transition queue
    """
    if warmup or len(self.transition_queue) <= t + self.lookahead_horizon:
        # Fewer than N transitions remain: fall back to the on-policy
        # Monte Carlo return over the rest of the queue
        lookahead = discount(
            [transition.reward for transition in self.transition_queue[t:]],
            self.gamma)[0]
        return Variable(Tensor([lookahead]))
    else:
        # Discounted reward sum from t to t + lookahead_horizon
        lookahead = discount(
            [transition.reward for transition in self.transition_queue[t:t + self.lookahead_horizon]],
            self.gamma)[0]
        # Bootstrap from the state at t + lookahead_horizon (Equation (3) in the paper)
        state = self.transition_queue[t + self.lookahead_horizon].state
        state_embedding = self.embedding_network(
            move_to_gpu(Variable(Tensor(state)).unsqueeze(0)))
        q_values = torch.cat([dnd.lookup(state_embedding) for dnd in self.dnd_list])
        # unsqueeze so this branch returns a 1-element tensor, matching the branch above
        return self.gamma ** self.lookahead_horizon * q_values.max().unsqueeze(0) + lookahead
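# The `discount` helper used above is not defined in this section. A minimal
# sketch under the assumption that discount(rewards, gamma) returns the
# discounted cumulative sum at every index, so that index [0] is the return
# from the start of the slice:
import numpy as np

def discount(rewards, gamma):
    """out[i] = rewards[i] + gamma * rewards[i+1] + gamma**2 * rewards[i+2] + ... (sketch)."""
    out = np.zeros(len(rewards))
    running = 0.0
    for i in reversed(range(len(rewards))):
        running = rewards[i] + gamma * running
        out[i] = running
    return out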
def warmup(self):
    """
    Warmup the DND with values from an episode with a random policy
    """
    state = self.env.reset()
    total_reward = 0
    done = False
    while not done:
        # Act uniformly at random and queue every transition
        action = random.randint(0, self.env.action_space.n - 1)
        next_state, reward, done, _ = self.env.step(action)
        total_reward += reward
        self.transition_queue.append(Transition(state, action, reward))
        state = next_state

    for t in range(len(self.transition_queue)):
        transition = self.transition_queue[t]
        state = Variable(Tensor(transition.state)).unsqueeze(0)
        action = transition.action
        state_embedding = self.embedding_network(move_to_gpu(state))
        dnd = self.dnd_list[action]

        Q_N = move_to_gpu(self.Q_lookahead(t, True))
        if dnd.keys_to_be_inserted is None and dnd.keys is None:
            # DND is still empty: insert unconditionally
            dnd.insert(state_embedding, Q_N.detach().unsqueeze(0))
        else:
            embedding_index = dnd.get_index(state_embedding)
            if embedding_index is None:
                dnd.insert(state_embedding.detach(), Q_N.detach().unsqueeze(0))
            else:
                Q = self.Q_update(dnd.values[embedding_index], Q_N)
                dnd.update(Q.detach(), embedding_index)
        self.replay_memory.push(transition.state, action, Q_N)
    for dnd in self.dnd_list:
        dnd.commit_insert()

    # Clear out the transition queue
    self.transition_queue = []
    return total_reward
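# Two small helpers referenced throughout this class are not shown in this
# section. Minimal sketches under plausible assumptions: Transition is taken to
# be a plain (state, action, reward) record, and move_to_gpu a CUDA guard.
from collections import namedtuple
import torch

Transition = namedtuple('Transition', ('state', 'action', 'reward'))

def move_to_gpu(x):
    """Move a tensor/Variable to the GPU when CUDA is available (sketch)."""
    return x.cuda() if torch.cuda.is_available() else x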
def episode(self):
    """
    Train an NEC agent for a single episode:
    interact with the environment on-policy, append all
    (state, action, reward) transitions to the transition queue,
    and call update() at the end of the episode
    """
    if self.epsilon > self.final_epsilon:
        self.epsilon = self.epsilon * self.epsilon_decay
    state = self.env.reset()
    total_reward = 0
    done = False
    while not done:
        state_embedding = self.embedding_network(
            move_to_gpu(Variable(Tensor(state)).unsqueeze(0)))
        action = self.choose_action(state_embedding)
        next_state, reward, done, _ = self.env.step(action)
        self.transition_queue.append(Transition(state, action, reward))
        total_reward += reward
        state = next_state
    self.update()
    return total_reward
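# choose_action() is used above but not defined in this section. Given the
# epsilon decay at the top of episode(), it is presumably epsilon-greedy over
# the per-action DND lookups; the sketch below is an assumption, not the
# original implementation.
import random
import torch

def choose_action(self, state_embedding):
    """Epsilon-greedy action selection over the per-action DND estimates (sketch)."""
    if random.random() < self.epsilon:
        return random.randint(0, len(self.dnd_list) - 1)
    q_estimates = torch.cat([dnd.lookup(state_embedding) for dnd in self.dnd_list])
    return int(q_estimates.max(0)[1])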