Example #1
class DQNAgent(object):
    def __init__(self, env, args, work_dir):
        self.env = env
        self.args = args
        self.work_dir = work_dir

        self.n_action = self.env.action_space.n
        self.arr_actions = np.arange(self.n_action)
        self.memory = ReplayMemory(self.args.buffer_size, self.args.device)
        # The online network is trained; the target network only provides
        # bootstrapped Q-targets and is re-synced from it every sync_period frames.
        self.qNetwork = ValueNetwork(self.n_action,
                                     self.env).to(self.args.device)
        self.targetNetwork = ValueNetwork(self.n_action,
                                          self.env).to(self.args.device)
        self.qNetwork.train()
        self.targetNetwork.eval()
        self.optimizer = optim.RMSprop(self.qNetwork.parameters(),
                                       lr=0.00025,
                                       eps=0.001,
                                       alpha=0.95)
        self.crit = nn.MSELoss()
        # Linearly anneal epsilon from eps down to eps_min over
        # exploration_decay_speed environment steps.
        self.eps = max(self.args.eps, self.args.eps_min)
        self.eps_delta = (
            self.eps - self.args.eps_min) / self.args.exploration_decay_speed

    def reset(self):
        # Build the initial state by stacking four copies of the first
        # preprocessed frame along the channel dimension.
        return torch.cat([preprocess_state(self.env.reset(), self.env)] * 4, 1)

    def select_action(self, state):
        # Epsilon-greedy: every action gets eps / n_action probability mass,
        # and the greedy action receives the remaining 1 - eps.
        action_prob = np.zeros(self.n_action, np.float32)
        action_prob.fill(self.eps / self.n_action)
        with torch.no_grad():
            max_q, max_q_index = self.qNetwork(
                state.to(self.args.device)).cpu().max(1)
        action_prob[max_q_index[0]] += 1 - self.eps
        action = np.random.choice(self.arr_actions, p=action_prob)
        next_state, reward, done, _ = self.env.step(action)
        # Slide the four-frame stack: drop the oldest frame, append the new one.
        next_state = torch.cat(
            [state.narrow(1, 1, 3),
             preprocess_state(next_state, self.env)], 1)
        self.memory.push(
            (state, torch.LongTensor([int(action)]), torch.Tensor([reward]),
             next_state, torch.Tensor([done])))
        return next_state, reward, done, max_q[0]

    def run(self):
        state = self.reset()
        # Pre-fill the replay buffer with transitions before learning starts.
        for _ in range(self.args.buffer_init_size):
            next_state, _, done, _ = self.select_action(state)
            state = self.reset() if done else next_state

        total_frame = 0
        reward_list = np.zeros(self.args.log_size, np.float32)
        qval_list = np.zeros(self.args.log_size, np.float32)

        start_time = time.time()

        for epi in count():
            reward_list[epi % self.args.log_size] = 0
            qval_list[epi % self.args.log_size] = -1e9
            state = self.reset()
            done = False
            ep_len = 0

            # Periodically checkpoint the online network.
            if epi % self.args.save_freq == 0:
                model_file = os.path.join(self.work_dir, 'model.th')
                with open(model_file, 'wb') as f:
                    torch.save(self.qNetwork, f)

            while not done:
                # Sync the target network with the online network every
                # sync_period frames.
                if total_frame % self.args.sync_period == 0:
                    self.targetNetwork.load_state_dict(
                        self.qNetwork.state_dict())

                self.eps = max(self.args.eps_min, self.eps - self.eps_delta)
                next_state, reward, done, qval = self.select_action(state)
                reward_list[epi % self.args.log_size] += reward
                qval_list[epi % self.args.log_size] = max(
                    qval_list[epi % self.args.log_size], qval)
                state = next_state

                total_frame += 1
                ep_len += 1

                if ep_len % self.args.learn_freq == 0:
                    batch_state, batch_action, batch_reward, batch_next_state, batch_done = self.memory.sample(
                        self.args.batch_size)
                    # Q(s, a) for the actions actually taken.
                    batch_q = self.qNetwork(batch_state).gather(
                        1, batch_action.unsqueeze(1)).squeeze(1)
                    # Bootstrap term: gamma * max_a' Q_target(s', a'), zeroed
                    # on terminal transitions.
                    batch_next_q = self.targetNetwork(batch_next_state).detach(
                    ).max(1)[0] * self.args.gamma * (1 - batch_done)
                    loss = self.crit(batch_q, batch_reward + batch_next_q)
                    self.optimizer.zero_grad()
                    loss.backward()
                    self.optimizer.step()

            output_str = 'episode %d frame %d time %.2fs cur_rew %.3f mean_rew %.3f cur_maxq %.3f mean_maxq %.3f' % (
                epi, total_frame, time.time() - start_time,
                reward_list[epi % self.args.log_size], np.mean(reward_list),
                qval_list[epi % self.args.log_size], np.mean(qval_list))
            print(output_str)
            logging.info(output_str)
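A minimal, hypothetical driver for the DQNAgent above might look like the sketch below. It assumes the helpers the example references (ValueNetwork, ReplayMemory, preprocess_state) are importable, that the environment follows the classic Gym API the example already uses (reset() returns an observation, step() returns a 4-tuple), and that args carries exactly the attribute names the agent reads; every value shown is an illustrative placeholder, not the original configuration.

import os
import gym
import torch
from types import SimpleNamespace

# Hypothetical hyperparameters: the attribute names mirror what DQNAgent reads
# from `args`; the values are placeholders.
args = SimpleNamespace(
    device=torch.device("cuda" if torch.cuda.is_available() else "cpu"),
    buffer_size=100000,
    buffer_init_size=10000,
    batch_size=32,
    gamma=0.99,
    eps=1.0,
    eps_min=0.1,
    exploration_decay_speed=1000000,
    log_size=100,
    save_freq=50,
    sync_period=10000,
    learn_freq=4)

work_dir = "./runs/dqn"
os.makedirs(work_dir, exist_ok=True)

env = gym.make("PongNoFrameskip-v4")  # any discrete-action environment
agent = DQNAgent(env, args, work_dir)
agent.run()  # run() loops over episodes indefinitely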
Example #2
class Agent:
    def __init__(self, state_size, action_size, num_agents):
        state_dim = state_size
        # agent_input_state_dim = state_size * 2  # Previous state is passed in with the current state.
        action_dim = action_size

        self.num_agents = num_agents

        max_size = 100000  # replay buffer capacity
        self.replay = Replay(max_size)

        hidden_dim = 128

        self.critic_net = ValueNetwork(state_dim, action_dim,
                                       hidden_dim).to(device)
        self.target_critic_net = ValueNetwork(state_dim, action_dim,
                                              hidden_dim).to(device)

        self.actor_net = PolicyNetwork(state_dim, action_dim,
                                       hidden_dim).to(device)
        self.target_actor_net = PolicyNetwork(state_dim, action_dim,
                                              hidden_dim).to(device)

        # Initialise the target networks as exact copies of the online networks.
        for target_param, param in zip(self.target_critic_net.parameters(),
                                       self.critic_net.parameters()):
            target_param.data.copy_(param.data)

        for target_param, param in zip(self.target_actor_net.parameters(),
                                       self.actor_net.parameters()):
            target_param.data.copy_(param.data)

        self.critic_optimizer = optim.Adam(self.critic_net.parameters(),
                                           lr=CRITIC_LEARNING_RATE)
        self.actor_optimizer = optim.Adam(self.actor_net.parameters(),
                                          lr=ACTOR_LEARNING_RATE)

    def get_action(self, state):
        return self.actor_net.get_action(state)[0]

    def add_replay(self, state, action, reward, next_state, done):
        # Store each agent's transition as a separate replay entry.
        for i in range(self.num_agents):
            self.replay.add(state[i], action[i], reward[i], next_state[i],
                            done[i])

    def learning_step(self):

        # Skip learning until the replay buffer holds at least one batch.
        if self.replay.cursize < BATCH_SIZE:
            return

        # Sample a batch of transitions.
        state, action, reward, next_state, done = self.replay.get(BATCH_SIZE)

        # Actor loss: maximise the critic's value of the actor's proposed actions.
        actor_loss = self.critic_net(state, self.actor_net(state))
        actor_loss = -actor_loss.mean()

        # Critic loss: one-step TD target built from the target networks.
        next_action = self.target_actor_net(next_state)
        target_value = self.target_critic_net(next_state, next_action.detach())
        expected_value = reward + (1.0 - done) * DISCOUNT_RATE * target_value

        value = self.critic_net(state, action)
        critic_loss = F.mse_loss(value, expected_value.detach())

        # Backpropagate and update the actor, then the critic.
        self.actor_optimizer.zero_grad()
        actor_loss.backward()
        self.actor_optimizer.step()

        self.critic_optimizer.zero_grad()
        critic_loss.backward()
        self.critic_optimizer.step()

        # Nudge the target networks toward the online networks (Polyak averaging).
        self.soft_update(self.critic_net, self.target_critic_net, TAU)
        self.soft_update(self.actor_net, self.target_actor_net, TAU)

    def save(self, name):
        torch.save(self.critic_net.state_dict(), name + "_critic")
        torch.save(self.actor_net.state_dict(), name + "_actor")

    def load(self, name):
        self.critic_net.load_state_dict(torch.load(name + "_critic"))
        self.critic_net.eval()
        self.actor_net.load_state_dict(torch.load(name + "_actor"))
        self.actor_net.eval()

        for target_param, param in zip(self.target_critic_net.parameters(),
                                       self.critic_net.parameters()):
            target_param.data.copy_(param.data)

        for target_param, param in zip(self.target_actor_net.parameters(),
                                       self.actor_net.parameters()):
            target_param.data.copy_(param.data)

    def soft_update(self, local_model, target_model, tau):
        """Soft update model parameters.
        θ_target = τ*θ_local + (1 - τ)*θ_target

        Params
        ======
            local_model (PyTorch model): weights will be copied from
            target_model (PyTorch model): weights will be copied to
            tau (float): interpolation parameter 
        """
        for target_param, local_param in zip(target_model.parameters(),
                                             local_model.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)
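As a usage sketch for the DDPG-style Agent above, the loop below shows how it might be driven. The module-level constants (device, BATCH_SIZE, DISCOUNT_RATE, TAU and the two learning rates) are assumed from how the class references them, the multi-agent environment interface (reset()/step() returning per-agent arrays) is a hypothetical stand-in, and passing per-agent states to get_action one at a time is likewise an assumption about PolicyNetwork.get_action.

import numpy as np
import torch

# Assumed module-level constants the Agent class references; values are placeholders.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 128
DISCOUNT_RATE = 0.99
TAU = 1e-3
ACTOR_LEARNING_RATE = 1e-4
CRITIC_LEARNING_RATE = 1e-3


def run_episode(env, agent, max_steps=1000):
    # One episode with a hypothetical multi-agent environment whose
    # reset()/step() return arrays of shape (num_agents, ...).
    state = env.reset()
    scores = np.zeros(agent.num_agents)
    for _ in range(max_steps):
        action = np.array([agent.get_action(s) for s in state])
        next_state, reward, done, _ = env.step(action)
        agent.add_replay(state, action, reward, next_state, done)
        agent.learning_step()
        scores += np.asarray(reward)
        state = next_state
        if np.any(done):
            break
    return scores.mean()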
Example #3
class SAC:
    def __init__(self, env_name, n_states, n_actions, memory_size, batch_size,
                 gamma, alpha, lr, action_bounds, reward_scale):
        self.env_name = env_name
        self.n_states = n_states
        self.n_actions = n_actions
        self.memory_size = memory_size
        self.batch_size = batch_size
        self.gamma = gamma
        self.alpha = alpha
        self.lr = lr
        self.action_bounds = action_bounds
        self.reward_scale = reward_scale
        self.memory = Memory(memory_size=self.memory_size)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.policy_network = PolicyNetwork(
            n_states=self.n_states,
            n_actions=self.n_actions,
            action_bounds=self.action_bounds).to(self.device)
        self.q_value_network1 = QvalueNetwork(n_states=self.n_states,
                                              n_actions=self.n_actions).to(
                                                  self.device)
        self.q_value_network2 = QvalueNetwork(n_states=self.n_states,
                                              n_actions=self.n_actions).to(
                                                  self.device)
        self.value_network = ValueNetwork(n_states=self.n_states).to(
            self.device)
        self.value_target_network = ValueNetwork(n_states=self.n_states).to(
            self.device)
        self.value_target_network.load_state_dict(
            self.value_network.state_dict())
        self.value_target_network.eval()

        self.value_loss = torch.nn.MSELoss()
        self.q_value_loss = torch.nn.MSELoss()

        self.value_opt = Adam(self.value_network.parameters(), lr=self.lr)
        self.q_value1_opt = Adam(self.q_value_network1.parameters(),
                                 lr=self.lr)
        self.q_value2_opt = Adam(self.q_value_network2.parameters(),
                                 lr=self.lr)
        self.policy_opt = Adam(self.policy_network.parameters(), lr=self.lr)

    def store(self, state, reward, done, action, next_state):
        # Keep replay transitions on the CPU; batches are moved to the
        # training device in unpack().
        state = from_numpy(state).float().to("cpu")
        reward = torch.Tensor([reward]).to("cpu")
        done = torch.Tensor([done]).to("cpu")
        action = torch.Tensor([action]).to("cpu")
        next_state = from_numpy(next_state).float().to("cpu")
        self.memory.add(state, reward, done, action, next_state)

    def unpack(self, batch):
        # Regroup the list of Transition tuples into one tensor per field.
        batch = Transition(*zip(*batch))

        states = torch.cat(batch.state).view(self.batch_size,
                                             self.n_states).to(self.device)
        rewards = torch.cat(batch.reward).view(self.batch_size,
                                               1).to(self.device)
        dones = torch.cat(batch.done).view(self.batch_size, 1).to(self.device)
        actions = torch.cat(batch.action).view(-1,
                                               self.n_actions).to(self.device)
        next_states = torch.cat(batch.next_state).view(
            self.batch_size, self.n_states).to(self.device)

        return states, rewards, dones, actions, next_states

    def train(self):
        # Do not train until the buffer holds at least one full batch.
        if len(self.memory) < self.batch_size:
            return 0, 0, 0
        else:
            batch = self.memory.sample(self.batch_size)
            states, rewards, dones, actions, next_states = self.unpack(batch)

            # Value target: V(s) ≈ min(Q1, Q2)(s, a~π) - α·log π(a|s),
            # using freshly sampled (reparameterised) actions.
            reparam_actions, log_probs = self.policy_network.sample_or_likelihood(
                states)
            q1 = self.q_value_network1(states, reparam_actions)
            q2 = self.q_value_network2(states, reparam_actions)
            q = torch.min(q1, q2)
            target_value = q.detach() - self.alpha * log_probs.detach()

            value = self.value_network(states)
            value_loss = self.value_loss(value, target_value)

            # Q-value target: scaled reward plus the discounted target value
            # of the next state, zeroed at terminal transitions.
            with torch.no_grad():
                target_q = self.reward_scale * rewards + \
                           self.gamma * self.value_target_network(next_states) * (1 - dones)
            q1 = self.q_value_network1(states, actions)
            q2 = self.q_value_network2(states, actions)
            q1_loss = self.q_value_loss(q1, target_q)
            q2_loss = self.q_value_loss(q2, target_q)

            # Policy loss: α·log π(a|s) - min(Q1, Q2), averaged over the batch.
            policy_loss = (self.alpha * log_probs - q).mean()

            self.policy_opt.zero_grad()
            policy_loss.backward()
            self.policy_opt.step()

            self.value_opt.zero_grad()
            value_loss.backward()
            self.value_opt.step()

            self.q_value1_opt.zero_grad()
            q1_loss.backward()
            self.q_value1_opt.step()

            self.q_value2_opt.zero_grad()
            q2_loss.backward()
            self.q_value2_opt.step()

            self.soft_update_target_network(self.value_network,
                                            self.value_target_network)

            return value_loss.item(), 0.5 * (
                q1_loss + q2_loss).item(), policy_loss.item()

    def choose_action(self, states):
        states = np.expand_dims(states, axis=0)
        states = from_numpy(states).float().to(self.device)
        action, _ = self.policy_network.sample_or_likelihood(states)
        return action.detach().cpu().numpy()[0]

    @staticmethod
    def soft_update_target_network(local_network, target_network, tau=0.005):
        for target_param, local_param in zip(target_network.parameters(),
                                             local_network.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1 - tau) * target_param.data)

    def save_weights(self):
        torch.save(self.policy_network.state_dict(),
                   self.env_name + "_weights.pth")

    def load_weights(self):
        self.policy_network.load_state_dict(
            torch.load(self.env_name + "_weights.pth"))

    def set_to_eval_mode(self):
        self.policy_network.eval()
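Finally, a hypothetical training loop for the SAC agent above. The environment, episode structure, and hyperparameter values are placeholders matched to the constructor signature; like the other examples, it assumes the classic Gym API where step() returns a 4-tuple.

import gym

env = gym.make("Pendulum-v1")  # any continuous-control task; placeholder choice
n_states = env.observation_space.shape[0]
n_actions = env.action_space.shape[0]
action_bounds = [env.action_space.low[0], env.action_space.high[0]]

agent = SAC(env_name="Pendulum-v1",
            n_states=n_states,
            n_actions=n_actions,
            memory_size=100000,
            batch_size=256,
            gamma=0.99,
            alpha=0.2,
            lr=3e-4,
            action_bounds=action_bounds,
            reward_scale=1.0)

for episode in range(200):
    state = env.reset()
    episode_reward, done = 0.0, False
    while not done:
        action = agent.choose_action(state)
        next_state, reward, done, _ = env.step(action)
        agent.store(state, reward, done, action, next_state)
        value_loss, q_loss, policy_loss = agent.train()
        episode_reward += reward
        state = next_state
    print(f"episode {episode}: reward {episode_reward:.1f}, "
          f"value loss {value_loss:.3f}, policy loss {policy_loss:.3f}")

agent.save_weights()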