class NECAgent(BaseAgent):
    def __init__(self, action_space, cmdl):
        BaseAgent.__init__(self, action_space)
        self.name = "NEC_agent"
        self.cmdl = cmdl
        self.dtype = TorchTypes()

        self.slow_lr = slow_lr = cmdl.slow_lr
        self.fast_lr = fast_lr = cmdl.fast_lr
        dnd = cmdl.dnd

        # Feature extractor and embedding size
        FeatureExtractor = get_estimator(cmdl.estimator)
        state_dim = (1, 24) if not cmdl.rescale else (1, 84)
        if dnd.linear_projection:
            self.conv = FeatureExtractor(state_dim, dnd.linear_projection)
        elif dnd.linear_projection is False:
            self.conv = FeatureExtractor(state_dim, None)
        embedding_size = self.conv.get_embedding_size()

        # DNDs, Memory, N-step buffer
        self.dnds = [
            DND(dnd.size, embedding_size, dnd.knn_no)
            for i in range(self.action_no)
        ]
        self.replay_memory = ReplayMemory(capacity=cmdl.experience_replay)
        self.N_step = self.cmdl.n_horizon
        self.N_buff = []

        self.optimizer = torch.optim.Adam(self.conv.parameters(), lr=slow_lr)
        self.optimizer.zero_grad()
        self.update_q = update_rule(fast_lr)

        # Temp data, flags, stats, misc
        self._key_tmp = None
        self.knn_ready = False
        self.initial_val = 0.1
        self.max_q = -math.inf

    def evaluate_policy(self, state):
        """ Policy Evaluation.

        Performs a forward pass through the neural net feature extractor
        and uses the resulting representation to compute the k nearest
        neighbors in each of the DNDs associated with each action.
        Returns the action with the highest weighted value among the
        k nearest neighbors.
        """
        state = torch.from_numpy(state).unsqueeze(0).unsqueeze(0)
        h = self.conv(Variable(state, volatile=True))
        self._key_tmp = h

        # corner case, randomly fill the buffers so that we can perform knn.
        if not self.knn_ready:
            return self._heat_up_dnd(h.data)

        # query each DND for q values and pick the largest one.
        if np.random.uniform() > self.cmdl.epsilon:
            v, action = self._query_dnds(h)
            self.max_q = v if self.max_q < v else self.max_q
            return action
        else:
            return self.action_space.sample()

    def improve_policy(self, _state, _action, reward, state, done):
        """ Policy Improvement. """
        self.N_buff.append((_state, self._key_tmp, _action, reward))

        R = 0
        if self.knn_ready and ((len(self.N_buff) == self.N_step) or done):
            if not done:
                # compute Q(t + N)
                state = torch.from_numpy(state).unsqueeze(0).unsqueeze(0)
                h = self.conv(Variable(state, volatile=True))
                R, _ = self._query_dnds(h)

            for i in range(len(self.N_buff) - 1, -1, -1):
                s = self.N_buff[i][0]
                h = self.N_buff[i][1]
                a = self.N_buff[i][2]

                R = self.N_buff[i][3] + 0.99 * R

                # write to DND
                self.dnds[a].write(h.data, R, self.update_q)
                # print("%3d, %3d, %3d | %0.3f" % (self.step_cnt, i, a, R))

                # append to experience replay
                self.replay_memory.push(s, a, R)

            self.N_buff.clear()

            for dnd in self.dnds:
                dnd.rebuild_tree()

        if self.cmdl.update_freq is False:
            return

        if (self.step_cnt % self.cmdl.update_freq == 0) and (
                len(self.replay_memory) > self.cmdl.batch_size):
            # get batch of transitions
            transitions = self.replay_memory.sample(self.cmdl.batch_size)
            batch = self._batch2torch(transitions)
            # compute gradients
            self._accumulate_gradient(*batch)
            # optimize
            self._update_model()

    def _query_dnds(self, h):
        q_vals = torch.FloatTensor(self.action_no, 1).fill_(0)
        for i, dnd in enumerate(self.dnds):
            q_vals[i] = dnd.lookup(h)
        return q_vals.max(0)[0][0, 0], q_vals.max(0)[1][0, 0]

    def _accumulate_gradient(self, states, actions, returns):
        """ Compute gradient. v = Q(s, a), return = QN(s, a). """
        states = Variable(states)
        actions = Variable(actions)
        returns = Variable(returns)

        # Compute Q(s, a)
        features = self.conv(states)
        v_variables = []
        for i in range(self.cmdl.batch_size):
            act = actions[i].data[0]
            v = self.dnds[act].lookup(features[i].unsqueeze(0), training=True)
            v_variables.append(v)

        q_values = torch.stack(v_variables)

        loss = F.smooth_l1_loss(q_values, returns)
        loss.data.clamp_(-1, 1)

        # Accumulate gradients
        loss.backward()

    def _update_model(self):
        for param in self.conv.parameters():
            param.grad.data.clamp_(-1, 1)
        self.optimizer.step()
        self.optimizer.zero_grad()

    def _heat_up_dnd(self, h):
        # fill the dnds with knn_no * (action_no + 1) entries
        action = np.random.randint(self.action_no)
        self.dnds[action].write(h, self.initial_val, self.update_q)
        self.knn_ready = self.step_cnt >= 2 * self.cmdl.dnd.knn_no * \
            (self.action_space.n + 1)
        if self.knn_ready:
            for dnd in self.dnds:
                dnd.rebuild_tree()
        return action

    def _batch2torch(self, batch, batch_sz=None):
        """ List of Transitions to list of torch states, actions, returns.

        From a batch of transitions (s0, a0, Rt) get a batch of the form
        state=(s0,s1...), action=(a1,a2...), Rt=(rt1,rt2...).
        Inefficient. Adds 1.5s~2s for 20,000 steps with 32 agents.
        """
        batch_sz = len(batch) if batch_sz is None else batch_sz
        batch = Transition(*zip(*batch))

        states = [torch.from_numpy(s).unsqueeze(0) for s in batch.state]
        state_batch = torch.stack(states).type(self.dtype.FloatTensor)
        action_batch = self.dtype.LongTensor(batch.action)
        rt_batch = self.dtype.FloatTensor(batch.Rt)

        return [state_batch, action_batch, rt_batch]

    def display_model_stats(self):
        param_abs_mean = 0
        grad_abs_mean = 0
        n_params = 0
        for p in self.conv.parameters():
            param_abs_mean += p.data.abs().sum()
            if p.grad is not None:
                grad_abs_mean += p.grad.data.abs().sum()
            n_params += p.data.nelement()
        print("[NEC_agent] step=%6d, Wm: %.9f" % (
            self.step_cnt, param_abs_mean / n_params))
        print("[NEC_agent] maxQ=%.3f " % (self.max_q))
        for i, dnd in enumerate(self.dnds):
            print("[DND] M=%d, DND.count=%d" % (len(dnd.M), dnd.count))
        for i, dnd in enumerate(self.dnds):
            print("[DND] old=%6d, new=%6d" % (dnd.old, dnd.new))
class DQNAgent(BaseAgent):
    def __init__(self, action_space, cmdl, is_training=True):
        BaseAgent.__init__(self, action_space, is_training)
        self.name = "DQN_agent"
        self.cmdl = cmdl
        eps = self.cmdl.epsilon
        e_steps = self.cmdl.epsilon_steps

        self.policy = policy = get_model(cmdl.estimator, 1, cmdl.hist_len,
                                         self.action_no, cmdl.hidden_size)
        self.target = target = get_model(cmdl.estimator, 1, cmdl.hist_len,
                                         self.action_no, cmdl.hidden_size)
        if self.cmdl.cuda:
            self.policy.cuda()
            self.target.cuda()

        self.policy_evaluation = DQNEvaluation(policy)
        self.policy_improvement = DQNImprovement(policy, target, cmdl)

        self.exploration = get_epsilon_schedule("linear", eps, 0.05, e_steps)
        self.replay_memory = ReplayMemory(capacity=cmdl.experience_replay)

        self.dtype = TorchTypes(cmdl.cuda)
        self.max_q = -1000

    def evaluate_policy(self, state):
        if self.is_training:
            self.epsilon = next(self.exploration)
        else:
            self.epsilon = 0.05

        if self.epsilon < uniform():
            state = self._frame2torch(state)
            qval, action = self.policy_evaluation.get_action(state)
            # print(qval, action)
            self.max_q = max(qval, self.max_q)
            return action
        else:
            return self.actions.sample()

    def improve_policy(self, _s, _a, r, s, done):
        self.replay_memory.push(_s, _a, s, r, done)

        if len(self.replay_memory) < self.cmdl.batch_size:
            return

        if (self.step_cnt % self.cmdl.update_freq == 0) and (
                len(self.replay_memory) > self.cmdl.batch_size):
            # get batch of transitions
            transitions = self.replay_memory.sample(self.cmdl.batch_size)
            batch = self._batch2torch(transitions)
            # compute gradients
            self.policy_improvement.accumulate_gradient(*batch)
            self.policy_improvement.update_model()

        if self.step_cnt % self.cmdl.target_update_freq == 0:
            self.policy_improvement.update_target_net()

    def display_model_stats(self):
        self.policy_improvement.get_model_stats()
        print("MaxQ=%2.2f. MemSz=%5d. Epsilon=%.2f." % (
            self.max_q, len(self.replay_memory), self.epsilon))

    def _frame2torch(self, s):
        state = torch.from_numpy(s).unsqueeze(0).unsqueeze(0)
        return state.type(self.dtype.FloatTensor)

    def _batch2torch(self, batch, batch_sz=None):
        """ List of Transitions to lists of torch tensors.

        From a batch of transitions (s, a, s_, r, done) get batches of the
        form state=(s0,s1...), action=(a0,a1...), reward=(r0,r1...),
        next_state=(s0_,s1_...) plus a terminal mask.
        """
        batch_sz = len(batch) if batch_sz is None else batch_sz
        batch = Transition(*zip(*batch))
        # print("[%s] Batch len=%d" % (self.name, batch_sz))

        states = [torch.from_numpy(s).unsqueeze(0) for s in batch.state]
        states_ = [torch.from_numpy(s).unsqueeze(0) for s in batch.state_]

        state_batch = torch.stack(states).type(self.dtype.FloatTensor)
        action_batch = self.dtype.LongTensor(batch.action)
        reward_batch = self.dtype.FloatTensor(batch.reward)
        next_state_batch = torch.stack(states_).type(self.dtype.FloatTensor)

        # Compute a mask for terminal next states
        # [True, False, False] -> [1, 0, 0]::ByteTensor
        mask = 1 - self.dtype.ByteTensor(batch.done)

        return [
            batch_sz, state_batch, action_batch, reward_batch,
            next_state_batch, mask
        ]
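
# `get_epsilon_schedule` is defined elsewhere in the repo; judging from the
# call sites above it returns an iterator that is consumed with `next()`
# once per training step and anneals epsilon from `eps` down to 0.05 over
# `e_steps` steps. A minimal sketch with those assumed semantics, for
# reference only (not the repo's implementation):
def linear_epsilon_schedule(start, end, steps):
    """ Yields epsilon annealed linearly from `start` to `end` over
    `steps` calls, then `end` forever. """
    decrement = (start - end) / steps
    eps = start
    while True:
        yield max(end, eps)
        eps -= decrement

# Example: exploration = linear_epsilon_schedule(1.0, 0.05, 1000)
#          epsilon = next(exploration)   # 1.0 on the first step
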