import math

import numpy as np
from numpy.random import uniform
import torch
import torch.nn.functional as F
from torch.autograd import Variable

# BaseAgent, TorchTypes, get_estimator, get_model, get_epsilon_schedule,
# update_rule, DND, ReplayMemory, Transition, DQNEvaluation and
# DQNImprovement are provided by the surrounding package.


class NECAgent(BaseAgent):
    def __init__(self, action_space, cmdl):
        BaseAgent.__init__(self, action_space)

        self.name = "NEC_agent"
        self.cmdl = cmdl
        self.dtype = TorchTypes()
        self.slow_lr = slow_lr = cmdl.slow_lr
        self.fast_lr = fast_lr = cmdl.fast_lr
        dnd = cmdl.dnd

        # Feature extractor and embedding size
        FeatureExtractor = get_estimator(cmdl.estimator)
        state_dim = (1, 24) if not cmdl.rescale else (1, 84)
        if dnd.linear_projection:
            self.conv = FeatureExtractor(state_dim, dnd.linear_projection)
        else:
            self.conv = FeatureExtractor(state_dim, None)
        embedding_size = self.conv.get_embedding_size()

        # DNDs, Memory, N-step buffer
        self.dnds = [
            DND(dnd.size, embedding_size, dnd.knn_no)
            for _ in range(self.action_no)
        ]
        self.replay_memory = ReplayMemory(capacity=cmdl.experience_replay)
        self.N_step = self.cmdl.n_horizon
        self.N_buff = []

        self.optimizer = torch.optim.Adam(self.conv.parameters(), lr=slow_lr)
        self.optimizer.zero_grad()
        self.update_q = update_rule(fast_lr)
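        # `update_rule(fast_lr)` is assumed to return the NEC tabular update
        # applied by DND.write() on exact key hits:
        #   Q_i <- Q_i + fast_lr * (R - Q_i)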

        # Temp data, flags, stats, misc
        self._key_tmp = None
        self.knn_ready = False
        self.initial_val = 0.1
        self.max_q = -math.inf

    def evaluate_policy(self, state):
        """ Policy Evaluation.

            Performs a forward operation through the neural net feature
            extractor and uses the resulting representation to compute the k
            nearest neighbors in each of the DNDs associated with each action.

            Returns the action with the highest estimated value, each action's
            value being a weighted average over its k nearest neighbors.
        """
        state = torch.from_numpy(state).unsqueeze(0).unsqueeze(0)
        h = self.conv(Variable(state, volatile=True))
        self._key_tmp = h

        # Corner case: act randomly and keep writing to the DNDs until they
        # hold enough entries to answer k-NN queries.
        if not self.knn_ready:
            return self._heat_up_dnd(h.data)

        # query each DND for q values and pick the largest one.
        if np.random.uniform() > self.cmdl.epsilon:
            v, action = self._query_dnds(h)
            self.max_q = v if self.max_q < v else self.max_q
            return action
        else:
            return self.action_space.sample()

    def improve_policy(self, _state, _action, reward, state, done):
        """ Policy Evaluation.
        """
        self.N_buff.append((_state, self._key_tmp, _action, reward))

        R = 0
        if self.knn_ready and ((len(self.N_buff) == self.N_step) or done):
            if not done:
                # compute Q(t + N)
                state = torch.from_numpy(state).unsqueeze(0).unsqueeze(0)
                h = self.conv(Variable(state, volatile=True))
                R, _ = self._query_dnds(h)

            # accumulate the discounted return backwards through the buffer
            for i in range(len(self.N_buff) - 1, -1, -1):
                s, h, a, r = self.N_buff[i]
                R = r + 0.99 * R  # hard-coded discount factor (gamma = 0.99)

                # write to DND
                self.dnds[a].write(h.data, R, self.update_q)
                # print("%3d, %3d, %3d  |  %0.3f" % (self.step_cnt, i, a, R))

                # append to experience replay
                self.replay_memory.push(s, a, R)

            self.N_buff.clear()

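            # rebuild the k-NN search structures so that subsequent lookups
            # see the newly written keys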
            for dnd in self.dnds:
                dnd.rebuild_tree()

        if self.cmdl.update_freq is False:
            return

        if (self.step_cnt % self.cmdl.update_freq
                == 0) and (len(self.replay_memory) > self.cmdl.batch_size):
            # get batch of transitions
            transitions = self.replay_memory.sample(self.cmdl.batch_size)
            batch = self._batch2torch(transitions)
            # compute gradients
            self._accumulate_gradient(*batch)
            # optimize
            self._update_model()

    def _query_dnds(self, h):
        q_vals = torch.FloatTensor(self.action_no, 1).fill_(0)
        for i, dnd in enumerate(self.dnds):
            q_vals[i] = dnd.lookup(h)
        max_val, max_idx = q_vals.max(0)
        return max_val[0, 0], max_idx[0, 0]

    def _accumulate_gradient(self, states, actions, returns):
        """ Compute gradient
            v=Q(s,a), return = QN(s,a)
        """
        states = Variable(states)
        actions = Variable(actions)
        returns = Variable(returns)

        # Compute Q(s, a)
        features = self.conv(states)

        v_variables = []
        for i in range(self.cmdl.batch_size):
            act = actions[i].data[0]
            v = self.dnds[act].lookup(features[i].unsqueeze(0), training=True)
            v_variables.append(v)

        q_values = torch.stack(v_variables)

        # Huber loss between DND estimates and N-step returns; gradient
        # clipping is applied in _update_model().
        loss = F.smooth_l1_loss(q_values, returns)

        # Accumulate gradients
        loss.backward()

    def _update_model(self):
        for param in self.conv.parameters():
            if param.grad is not None:
                param.grad.data.clamp_(-1, 1)

        self.optimizer.step()
        self.optimizer.zero_grad()

    def _heat_up_dnd(self, h):
        # Warm-up: write entries under a random policy until the DNDs hold
        # roughly 2 * knn_no * (action_no + 1) keys in total, enough to
        # answer k-NN queries.
        action = np.random.randint(self.action_no)
        self.dnds[action].write(h, self.initial_val, self.update_q)
        self.knn_ready = self.step_cnt >= 2 * self.cmdl.dnd.knn_no * \
            (self.action_space.n + 1)
        if self.knn_ready:
            for dnd in self.dnds:
                dnd.rebuild_tree()
        return action

    def _batch2torch(self, batch, batch_sz=None):
        """ List of Transitions to List of torch states, actions, rewards.

            From a batch of transitions (s0, a0, Rt)
            get a batch of the form state=(s0,s1...), action=(a1,a2...),
            Rt=(rt1,rt2...)
            Inefficient. Adds 1.5s~2s for 20,000 steps with 32 agents.
        """
        batch_sz = len(batch) if batch_sz is None else batch_sz
        batch = Transition(*zip(*batch))

        states = [torch.from_numpy(s).unsqueeze(0) for s in batch.state]
        state_batch = torch.stack(states).type(self.dtype.FloatTensor)
        action_batch = self.dtype.LongTensor(batch.action)
        rt_batch = self.dtype.FloatTensor(batch.Rt)
        return [state_batch, action_batch, rt_batch]

    def display_model_stats(self):
        param_abs_mean = 0
        grad_abs_mean = 0
        n_params = 0
        for p in self.conv.parameters():
            param_abs_mean += p.data.abs().sum()
            if p.grad is not None:
                grad_abs_mean += p.grad.data.abs().sum()
            n_params += p.data.nelement()

        print("[NEC_agent] step=%6d, Wm: %.9f" %
              (self.step_cnt, param_abs_mean / n_params))
        print("[NEC_agent] maxQ=%.3f " % (self.max_q))
        for i, dnd in enumerate(self.dnds):
            print("[DND %d] M=%d, count=%d" % (i, len(dnd.M), dnd.count))
        for i, dnd in enumerate(self.dnds):
            print("[DND %d] old=%6d, new=%6d" % (i, dnd.old, dnd.new))


class DQNAgent(BaseAgent):
    def __init__(self, action_space, cmdl, is_training=True):
        BaseAgent.__init__(self, action_space, is_training)
        self.name = "DQN_agent"
        self.cmdl = cmdl
        eps = self.cmdl.epsilon
        e_steps = self.cmdl.epsilon_steps

        self.policy = policy = get_model(cmdl.estimator, 1, cmdl.hist_len,
                                         self.action_no, cmdl.hidden_size)
        self.target = target = get_model(cmdl.estimator, 1, cmdl.hist_len,
                                         self.action_no, cmdl.hidden_size)
        if self.cmdl.cuda:
            self.policy.cuda()
            self.target.cuda()
        self.policy_evaluation = DQNEvaluation(policy)
        self.policy_improvement = DQNImprovement(policy, target, cmdl)
        self.exploration = get_epsilon_schedule("linear", eps, 0.05, e_steps)
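        # `get_epsilon_schedule("linear", eps, 0.05, e_steps)` is assumed to
        # return a generator that anneals epsilon linearly from `eps` down to
        # 0.05 over `e_steps` steps, i.e. roughly
        #   eps_t = max(0.05, eps - (eps - 0.05) * t / e_steps)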
        self.replay_memory = ReplayMemory(capacity=cmdl.experience_replay)

        self.dtype = TorchTypes(cmdl.cuda)
        self.max_q = -1000

    def evaluate_policy(self, state):
        if self.is_training:
            self.epsilon = next(self.exploration)
        else:
            self.epsilon = 0.05

        if self.epsilon < uniform():
            state = self._frame2torch(state)
            qval, action = self.policy_evaluation.get_action(state)
            # print(qval, action)
            self.max_q = max(qval, self.max_q)
            return action
        else:
            return self.action_space.sample()

    def improve_policy(self, _s, _a, r, s, done):
        self.replay_memory.push(_s, _a, s, r, done)

        if len(self.replay_memory) < self.cmdl.batch_size:
            return

        if (self.step_cnt % self.cmdl.update_freq
                == 0) and (len(self.replay_memory) > self.cmdl.batch_size):
            # get batch of transitions
            transitions = self.replay_memory.sample(self.cmdl.batch_size)
            batch = self._batch2torch(transitions)
            # compute gradients
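            # (assumed to regress Q(s, a) onto the standard DQN target
            #  y = r + gamma * (1 - done) * max_a' Q_target(s', a'))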
            self.policy_improvement.accumulate_gradient(*batch)
            self.policy_improvement.update_model()

        if self.step_cnt % self.cmdl.target_update_freq == 0:
            self.policy_improvement.update_target_net()

    def display_model_stats(self):
        self.policy_improvement.get_model_stats()
        print("MaxQ=%2.2f.  MemSz=%5d.  Epsilon=%.2f." %
              (self.max_q, len(self.replay_memory), self.epsilon))

    def _frame2torch(self, s):
        state = torch.from_numpy(s).unsqueeze(0).unsqueeze(0)
        return state.type(self.dtype.FloatTensor)

    def _batch2torch(self, batch, batch_sz=None):
        """ List of Transitions to List of torch states, actions, rewards.
            From a batch of transitions (s0, a0, Rt)
            get a batch of the form state=(s0,s1...), action=(a1,a2...),
            Rt=(rt1,rt2...)
        """
        batch_sz = len(batch) if batch_sz is None else batch_sz
        batch = Transition(*zip(*batch))
        # print("[%s] Batch len=%d" % (self.name, batch_sz))

        states = [torch.from_numpy(s).unsqueeze(0) for s in batch.state]
        states_ = [torch.from_numpy(s).unsqueeze(0) for s in batch.state_]

        state_batch = torch.stack(states).type(self.dtype.FloatTensor)
        action_batch = self.dtype.LongTensor(batch.action)
        reward_batch = self.dtype.FloatTensor(batch.reward)
        next_state_batch = torch.stack(states_).type(self.dtype.FloatTensor)

        # Compute a mask for terminal next states
        # [True, False, False] -> [1, 0, 0]::ByteTensor
        mask = 1 - self.dtype.ByteTensor(batch.done)

        return [
            batch_sz, state_batch, action_batch, reward_batch,
            next_state_batch, mask
        ]
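

# Usage sketch, assuming a Gym-style `env` and a config object `cmdl` carrying
# the fields referenced above; BaseAgent is assumed to maintain `step_cnt`,
# and the loop names below are illustrative only:
#
#   agent = DQNAgent(env.action_space, cmdl)
#   state = env.reset()
#   for _ in range(cmdl.training_steps):
#       action = agent.evaluate_policy(state)
#       next_state, reward, done, _ = env.step(action)
#       agent.improve_policy(state, action, reward, next_state, done)
#       state = env.reset() if done else next_state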