Example #1
0
class DQNAgent(object):
    """Epsilon-greedy DQN agent with experience replay and a target network.

    The state fed to the Q-network is four preprocessed observations
    concatenated along dim 1 (see ``reset`` / ``select_action``). Training
    uses an MSE temporal-difference loss optimised with RMSprop, and the
    target network is overwritten with the online weights every
    ``args.sync_period`` frames.

    ``args`` must provide: ``buffer_size``, ``device``, ``eps``, ``eps_min``,
    ``exploration_decay_speed``, ``buffer_init_size``, ``log_size``,
    ``save_freq``, ``sync_period``, ``learn_freq``, ``batch_size`` and
    ``gamma``.
    """

    def __init__(self, env, args, work_dir):
        self.env = env
        self.args = args
        self.work_dir = work_dir

        self.n_action = self.env.action_space.n
        self.arr_actions = np.arange(self.n_action)
        self.memory = ReplayMemory(self.args.buffer_size, self.args.device)
        self.qNetwork = ValueNetwork(self.n_action,
                                     self.env).to(self.args.device)
        self.targetNetwork = ValueNetwork(self.n_action,
                                          self.env).to(self.args.device)
        self.qNetwork.train()
        self.targetNetwork.eval()
        self.optimizer = optim.RMSprop(self.qNetwork.parameters(),
                                       lr=0.00025,
                                       eps=0.001,
                                       alpha=0.95)
        self.crit = nn.MSELoss()
        # Epsilon is annealed linearly from max(eps, eps_min) down to eps_min,
        # losing eps_delta per environment frame (see ``run``).
        self.eps = max(self.args.eps, self.args.eps_min)
        self.eps_delta = (
            self.eps - self.args.eps_min) / self.args.exploration_decay_speed

    def reset(self):
        """Reset the env and return its first observation repeated 4x along dim 1."""
        return torch.cat([preprocess_state(self.env.reset(), self.env)] * 4, 1)

    def select_action(self, state):
        """Take one epsilon-greedy env step from ``state`` and store the transition.

        Every action gets probability ``eps / n_action`` and the greedy action
        gets an extra ``1 - eps``. The resulting transition
        ``(state, action, reward, next_state, done)`` is pushed to replay
        memory.

        Returns:
            ``(next_state, reward, done, max_q)``. ``next_state`` keeps
            indices 1-3 of ``state`` along dim 1 and appends the newly
            preprocessed observation. ``max_q`` is the greedy Q-value for
            ``state`` as a 0-d CPU tensor.
        """
        action_prob = np.full(self.n_action, self.eps / self.n_action,
                              np.float32)
        # Action selection only needs Q-values, not an autograd graph.
        with torch.no_grad():
            q_values = self.qNetwork(state.to(self.args.device)).cpu()
        max_q, max_q_index = q_values.max(1)
        action_prob[int(max_q_index[0])] += 1 - self.eps
        action = np.random.choice(self.arr_actions, p=action_prob)
        next_state, reward, done, _ = self.env.step(action)
        next_state = torch.cat(
            [state.narrow(1, 1, 3),
             preprocess_state(next_state, self.env)], 1)
        self.memory.push(
            (state, torch.LongTensor([int(action)]), torch.Tensor([reward]),
             next_state, torch.Tensor([done])))
        return next_state, reward, done, max_q[0]

    @staticmethod
    def _window_mean(values, n_filled):
        """Return the mean of the first ``n_filled`` entries of ``values``.

        ``n_filled`` is clamped to ``len(values)``. Returns 0.0 when nothing
        has been filled yet.
        """
        n = min(n_filled, len(values))
        if n <= 0:
            return 0.0
        return float(np.mean(values[:n]))

    def _save_model(self):
        """Serialize the whole online Q-network to ``<work_dir>/model.th``."""
        model_file = os.path.join(self.work_dir, 'model.th')
        with open(model_file, 'wb') as f:
            torch.save(self.qNetwork, f)

    def _learn(self):
        """Run one TD-learning gradient step on a minibatch from replay memory."""
        (batch_state, batch_action, batch_reward, batch_next_state,
         batch_done) = self.memory.sample(self.args.batch_size)
        # Q(s, a) for the actions actually taken.
        batch_q = self.qNetwork(batch_state).gather(
            1, batch_action.unsqueeze(1)).squeeze(1)
        # Bootstrapped target from the target network; (1 - done) removes
        # the bootstrap term for terminal transitions.
        batch_next_q = self.targetNetwork(batch_next_state).detach(
        ).max(1)[0] * self.args.gamma * (1 - batch_done)
        loss = self.crit(batch_q, batch_reward + batch_next_q)
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    def run(self):
        """Fill the replay buffer, then train indefinitely, logging each episode.

        Never returns: episodes are counted with ``itertools.count``. The
        online network is saved every ``args.save_freq`` episodes, and a
        gradient step is taken every ``args.learn_freq`` frames of an
        episode.
        """
        state = self.reset()
        # Warm up the replay buffer before any learning happens.
        for _ in range(self.args.buffer_init_size):
            next_state, _, done, _ = self.select_action(state)
            state = self.reset() if done else next_state

        total_frame = 0
        # Ring buffers holding per-episode stats for the last log_size episodes.
        reward_list = np.zeros(self.args.log_size, np.float32)
        qval_list = np.zeros(self.args.log_size, np.float32)

        start_time = time.time()

        for epi in count():
            slot = epi % self.args.log_size
            reward_list[slot] = 0
            qval_list[slot] = -1e9  # running max of greedy Q in this episode
            state = self.reset()
            done = False
            ep_len = 0

            if epi % self.args.save_freq == 0:
                self._save_model()

            while not done:
                if total_frame % self.args.sync_period == 0:
                    self.targetNetwork.load_state_dict(
                        self.qNetwork.state_dict())

                self.eps = max(self.args.eps_min, self.eps - self.eps_delta)
                next_state, reward, done, qval = self.select_action(state)
                reward_list[slot] += reward
                qval_list[slot] = max(qval_list[slot], qval)
                state = next_state

                total_frame += 1
                ep_len += 1

                if ep_len % self.args.learn_freq == 0:
                    self._learn()

            # Until the ring buffer wraps only epi + 1 slots hold real data;
            # averaging the whole buffer would mix in the zero initialisers
            # and bias the early means toward 0.
            n_logged = epi + 1
            output_str = 'episode %d frame %d time %.2fs cur_rew %.3f mean_rew %.3f cur_maxq %.3f mean_maxq %.3f' % (
                epi, total_frame, time.time() - start_time,
                reward_list[slot], self._window_mean(reward_list, n_logged),
                qval_list[slot], self._window_mean(qval_list, n_logged))
            print(output_str)
            logging.info(output_str)
Example #2
0
class SAC:
    """Soft Actor-Critic agent with twin Q critics and a soft-updated target critic.

    The critic exposes two Q heads (``critic.q1`` / ``critic.q2``), each with
    its own Adam optimizer. Targets use the minimum of the two target-critic
    heads. The target critic is moved toward the online critic by
    ``polyak`` after every training step.
    """

    def __init__(self,
                 env,
                 lr=3e-4,
                 gamma=0.99,
                 polyak=5e-3,
                 alpha=0.2,
                 reward_scale=1.0,
                 cuda=True,
                 writer=None):
        state_size = env.observation_space.shape[0]
        action_size = env.action_space.shape[0]
        self.actor = Actor(state_size, action_size)
        self.critic = Critic(state_size, action_size)
        self.target_critic = Critic(state_size, action_size).eval()
        self.actor_optimizer = optim.Adam(self.actor.parameters(), lr=lr)
        self.q1_optimizer = optim.Adam(self.critic.q1.parameters(), lr=lr)
        self.q2_optimizer = optim.Adam(self.critic.q2.parameters(), lr=lr)

        # Start the target critic as an exact copy and freeze it; it is only
        # ever changed by ``target_update``.
        self.target_critic.load_state_dict(self.critic.state_dict())
        for param in self.target_critic.parameters():
            param.requires_grad = False

        self.memory = ReplayMemory()

        self.gamma = gamma
        self.alpha = alpha
        # Soft-update rate: each ``target_update`` moves the target weights
        # this fraction of the way toward the online critic, so it lies in
        # (0, 1] and is usually small (default 5e-3).
        self.polyak = polyak
        self.reward_scale = reward_scale

        self.writer = writer

        self.cuda = cuda
        if cuda:
            self.actor = self.actor.to('cuda')
            self.critic = self.critic.to('cuda')
            self.target_critic = self.target_critic.to('cuda')

    def _device(self):
        """Device name that networks and batches live on."""
        return 'cuda' if self.cuda else 'cpu'

    def _as_batch(self, state):
        """Convert one observation to a float32 batch of size 1 on the agent's device."""
        return torch.tensor(state, dtype=torch.float).unsqueeze(0).to(
            self._device())

    def explore(self, state):
        """Sample a stochastic action for one observation.

        Returns a flat numpy array. Uses the first value returned by
        ``actor.sample``.
        """
        with torch.no_grad():
            action, _, _ = self.actor.sample(self._as_batch(state))
        return action.cpu().detach().numpy().reshape(-1)

    def exploit(self, state):
        """Return the actor's evaluation action for one observation.

        Returns a flat numpy array. Uses the third value returned by
        ``actor.sample``.
        """
        with torch.no_grad():
            _, _, action = self.actor.sample(self._as_batch(state))
        return action.cpu().detach().numpy().reshape(-1)

    def store_step(self, state, action, next_state, reward, terminal):
        """Convert one transition to tensors and push it into replay memory."""
        state = to_tensor_unsqueeze(state)
        if isinstance(action, np.ndarray):
            # Actions are later combined with torch.stack, which needs
            # tensors. float32 arrays are wrapped unchanged; other numpy
            # dtypes are cast to float32 rather than stored as raw arrays.
            action = torch.from_numpy(action.astype(np.float32, copy=False))
        next_state = to_tensor_unsqueeze(next_state)
        # np.float was removed in NumPy 1.24; np.float64 is the dtype it aliased.
        reward = torch.from_numpy(np.array([reward]).astype(np.float64))
        terminal = torch.from_numpy(np.array([terminal]).astype(np.uint8))
        self.memory.push(state, action, next_state, reward, terminal)

    def target_update(self, target_net, net):
        """Soft update: target <- (1 - polyak) * target + polyak * net, in place."""
        for t, s in zip(target_net.parameters(), net.parameters()):
            t.data.mul_(1.0 - self.polyak)
            t.data.add_(self.polyak * s.data)

    def calc_target_q(self, next_states, rewards, terminals):
        """Compute the soft Bellman target without tracking gradients.

        ``entropy`` is the second value from ``actor.sample``. It is
        subtracted as ``alpha * entropy``, i.e. it is used like the action
        log-probability (confirm against ``Actor.sample``).
        """
        with torch.no_grad():
            next_action, entropy, _ = self.actor.sample(next_states)
            next_q1, next_q2 = self.target_critic(next_states, next_action)
            # Clipped double-Q: take the smaller of the two target heads.
            next_q = torch.min(next_q1, next_q2) - self.alpha * entropy
        target_q = rewards * self.reward_scale + (
            1. - terminals) * self.gamma * next_q
        return target_q

    def calc_critic_loss(self, states, actions, next_states, rewards,
                         terminals):
        """Return the mean squared TD errors ``(q1_loss, q2_loss)`` of both heads."""
        q1, q2 = self.critic(states, actions)
        target_q = self.calc_target_q(next_states, rewards, terminals)

        q1_loss = torch.mean((q1 - target_q).pow(2))
        q2_loss = torch.mean((q2 - target_q).pow(2))
        return q1_loss, q2_loss

    def calc_actor_loss(self, states):
        """Return ``(actor_loss, entropy)`` for freshly sampled actions.

        The loss is ``mean(alpha * entropy - min(q1, q2))``.
        """
        action, entropy, _ = self.actor.sample(states)

        q1, q2 = self.critic(states, action)
        q = torch.min(q1, q2)

        actor_loss = (self.alpha * entropy - q).mean()
        return actor_loss, entropy

    def train(self, timestep, batch_size=256):
        """Run one SAC update (both critics, actor, target) from a replay batch.

        Does nothing until memory holds at least ``batch_size`` transitions.
        Losses are written to ``writer`` every 100 timesteps when a writer
        is set.
        """
        if len(self.memory) < batch_size:
            return

        transitions = Transition(*zip(*self.memory.sample(batch_size)))
        device = self._device()
        states = torch.cat(transitions.state).to(device)
        actions = torch.stack(transitions.action).to(device)
        next_states = torch.cat(transitions.next_state).to(device)
        rewards = torch.stack(transitions.reward).to(device)
        terminals = torch.stack(transitions.terminal).to(device)

        q1_loss, q2_loss = self.calc_critic_loss(states, actions, next_states,
                                                 rewards, terminals)
        update_params(self.q1_optimizer, self.critic.q1, q1_loss)
        update_params(self.q2_optimizer, self.critic.q2, q2_loss)

        # The actor loss is built only after the critic step. Building it
        # earlier and backpropagating after the optimizers modified the
        # critic weights in place makes autograd raise an in-place
        # modification error.
        actor_loss, _ = self.calc_actor_loss(states)
        update_params(self.actor_optimizer, self.actor, actor_loss)

        self.target_update(self.target_critic, self.critic)

        if timestep % 100 == 0 and self.writer is not None:
            self.writer.add_scalar('Loss/Actor', actor_loss.item(), timestep)
            self.writer.add_scalar('Loss/Critic', q1_loss.item(), timestep)

    def save_weights(self, path):
        """Save actor and critic via their own ``save`` methods under ``path``."""
        self.actor.save(os.path.join(path, 'actor'))
        self.critic.save(os.path.join(path, 'critic'))