Example #1
class QAgent:
    def __init__(self, env, dim_maker, record_dir):
        self.memory = ExperienceMemory(batch_size=BATCH_SIZE,
                                       msize=BUFFER_SIZE)
        self.env = env
        self.dim_maker = dim_maker
        self.eps_handler = SoftEpsilonDecay(1.0, 1e-3, 0.995)
        self.local_network = CategoricalActorCritic(state_size=env.obs_dim,
                                                    action_size=env.act_dim,
                                                    shared_feature_size=8,
                                                    fc_units=[64, 32])

        self.target_network = CategoricalActorCritic(state_size=env.obs_dim,
                                                     action_size=env.act_dim,
                                                     shared_feature_size=8,
                                                     fc_units=[64, 32])
        self.optimizer = optim.Adam(self.local_network.parameters(), lr=NAV_LR)
        soft_update(target_net=self.target_network,
                    local_net=self.local_network,
                    tau=1.0)
        #network_dict = {"local": self.local_network, "target": self.target_network}
        #self.sr_service = SaveRestoreService(record_dir, network_dict)
        #self.sr_service.restore()

    def act(self, obs):
        # Query the local actor-critic for an action and map it to the env's action format.
        _, actions_t, _, _ = self.local_network.forward(obs)
        return self.dim_maker.env_in(actions_t)
        # Earlier epsilon-greedy variant, kept for reference:
        # if random.random() >= self.eps_handler.eps:
        #     estimate_v_t, _, _, _ = self.local_network.forward(obs)
        #     estimate_v = self.dim_maker.agent_out_to_np(estimate_v_t)
        #     action = np.argmax(estimate_v, axis=1)
        # else:
        #     action = random.choice(np.arange(self.env.act_dim))
        # action = np.reshape(action, (-1, 1))
        # return action

    def update(self, state, action, reward, next_state, done, episode):
        self.eps_handler.decay()
        self.memory.add(state, action, reward, next_state, done)
        if len(self.memory) < BATCH_SIZE:
            return
        b_states_t, b_actions_t, b_rewards_t, b_next_states_t, b_dones_t = self.memory.sample(
            self.dim_maker)
        # TD target: r + gamma * max_a' Q_target(s', a'), zeroed out for terminal transitions.
        vnext_target_t, _, _, _ = self.target_network.forward(b_next_states_t)
        max_vnext_target_t = vnext_target_t.max(1)[0].reshape(-1, 1)
        vtarget_t = b_rewards_t + DISCOUNT_RATE * (1 - b_dones_t) * max_vnext_target_t

        # Q-values of the actions actually taken, from the local network.
        vlocal_t, _, _, _ = self.local_network.forward(b_states_t)
        vlocal_t = vlocal_t.gather(1, b_actions_t.long())
        loss = F.mse_loss(vlocal_t, vtarget_t.detach())
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        soft_update(target_net=self.target_network,
                    local_net=self.local_network,
                    tau=0.1)
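
Both QAgent variants rely on a soft_update helper that is not shown in these snippets. Below is a minimal sketch, assuming it performs the usual Polyak averaging of parameters; with tau=1.0 (as in the constructors above) it degenerates to a hard copy of the local network into the target network.

import torch


def soft_update(target_net, local_net, tau):
    # Assumed implementation: theta_target <- tau * theta_local + (1 - tau) * theta_target.
    # tau=1.0 copies the local weights outright; a small tau lets the target trail slowly.
    with torch.no_grad():
        for target_param, local_param in zip(target_net.parameters(),
                                             local_net.parameters()):
            target_param.data.copy_(tau * local_param.data +
                                    (1.0 - tau) * target_param.data)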
Example #2
    def __init__(self, env, dim_maker, record_dir):
        self.memory = ExperienceMemory(batch_size=BATCH_SIZE, msize=BUFFER_SIZE)
        self.env = env
        self.dim_maker = dim_maker
        self.eps_handler = SoftEpsilonDecay(1.0, 1e-3, 0.995)
        self.local_network = NavModel(env.obs_dim, env.act_dim)
        self.target_network = NavModel(env.obs_dim, env.act_dim)
        self.optimizer = optim.Adam(self.local_network.parameters(), lr=NAV_LR)
        soft_update(target_net=self.target_network, local_net=self.local_network, tau=1.0)
        network_dict = {"local": self.local_network, "target": self.target_network}
        self.sr_service = SaveRestoreService(record_dir, network_dict)
        self.sr_service.restore()
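
Every QAgent constructor builds SoftEpsilonDecay(1.0, 1e-3, 0.995), whose implementation is also not part of these snippets. The following is a hypothetical sketch that matches how .eps and .decay() are used in Example #1 (geometric decay clipped at a floor):

class SoftEpsilonDecay:
    """Hypothetical schedule: start at eps_start, multiply by decay_rate on each
    decay() call, and never drop below eps_end."""

    def __init__(self, eps_start, eps_end, decay_rate):
        self.eps = eps_start
        self.eps_end = eps_end
        self.decay_rate = decay_rate

    def decay(self):
        # Shrink epsilon geometrically, but keep it above the floor.
        self.eps = max(self.eps_end, self.eps * self.decay_rate)
        return self.eps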
Example #3
class TennisMultiAgent:
    def __init__(self, env_driver, dim_maker):
        super(TennisMultiAgent, self).__init__()
        self.env_driver = env_driver
        self.noise = OUNoise((env_driver.num_agents, env_driver.act_dim),
                             sigma_decay=0.9995,
                             seed=RANDOM_SEED)
        self.memory = ExperienceMemory(batch_size=BATCH_SIZE,
                                       msize=BUFFER_SIZE)
        self.ddpg_agent_list = [
            TennisAgent(env_driver, dim_maker)
            for i in range(env_driver.num_agents)
        ]
        self.dim_maker = dim_maker

    def reset(self):
        self.noise.reset()

    def save(self, save_dir):
        os.makedirs(save_dir, exist_ok=True)
        for i, agent in enumerate(self.ddpg_agent_list):
            an_filename = os.path.join(save_dir,
                                       "agent_actor_{}.pth".format(i))
            torch.save(agent.actor_local.state_dict(), an_filename)
            cn_filename = os.path.join(save_dir,
                                       "agent_critic_{}.pth".format(i))
            torch.save(agent.critic_local.state_dict(), cn_filename)

    def act_for_env(self, agent_obs):
        env_driver = self.env_driver
        actions = np.zeros((env_driver.num_agents, env_driver.act_dim))
        for i, agent in enumerate(self.ddpg_agent_list):
            obs_t = self.dim_maker.agent_in(agent_obs[i, :].reshape(1, -1))
            actions[i, :] = agent.act_for_env(obs_t)
        noise = self.noise.sample()
        actions += noise
        return np.clip(actions, -1, 1)  # all actions between -1 and 1

    def update(self, states, actions, rewards, next_states, dones):
        self.memory.add(state=states,
                        action=actions,
                        reward=rewards,
                        next_state=next_states,
                        done=dones)
        if len(self.memory) >= BATCH_SIZE:
            for a_idx in range(self.env_driver.num_agents):
                self.__learn(self.memory.sample(self.dim_maker), a_idx)

    def __learn(self, samples, agent_id):
        env_driver = self.env_driver
        b_states_t, b_actions_t, b_rewards_t, b_next_states_t, b_dones_t = samples
        this_agent = self.ddpg_agent_list[agent_id]

        # ---------------- critic update ----------------
        # Step 1: compute every agent's next action with this agent's target actor.
        a_next_t_list = []
        for i in range(env_driver.num_agents):
            a_next_t_list.append(
                this_agent.actor_target.forward(b_next_states_t[:, i, :]))
        a_next_t = torch.cat(a_next_t_list, dim=1)

        q_next_t = this_agent.critic_target.forward(
            torch.cat([b_next_states_t.reshape(BATCH_SIZE, -1), a_next_t],
                      dim=1))
        r = b_rewards_t[:, agent_id].reshape(-1, 1)
        d_1 = 1 - b_dones_t[:, agent_id].reshape(-1, 1)
        q_target_t = r + DISCOUNT_RATE * q_next_t * (d_1)
        assert (q_target_t.shape[0] == BATCH_SIZE and q_target_t.shape[1] == 1)

        critic_input = torch.cat(
            [b_states_t.reshape(BATCH_SIZE, -1),
             b_actions_t.reshape(BATCH_SIZE, -1)],
            dim=1)
        assert (critic_input.shape[0] == BATCH_SIZE)
        v = this_agent.critic_local.forward(critic_input)
        assert (v.shape[0] == BATCH_SIZE and v.shape[1] == 1)

        #huber_loss = torch.nn.SmoothL1Loss()
        #critic_loss = huber_loss(v, q_target_t)
        critic_loss = F.mse_loss(v, q_target_t)
        this_agent.critic_optimizer.zero_grad()
        critic_loss.backward()
        this_agent.critic_optimizer.step()

        # ---------------- actor update: optimize this agent's policy only ----------------
        this_agent.actor_optimizer.zero_grad()
        action_list = []
        for i, agent in enumerate(self.ddpg_agent_list):
            obs = b_states_t[:, i, :]
            assert (obs.shape[0] == BATCH_SIZE
                    and obs.shape[1] == env_driver.obs_dim)
            action = agent.actor_local.forward(obs)
            if i != agent_id:
                # Detach other agents' actions so gradients only flow into this agent's actor.
                action = action.detach()
            action_list.append(action)
        actions_t = torch.cat(action_list, dim=1)
        assert (actions_t.shape[0] == BATCH_SIZE and actions_t.shape[1]
                == env_driver.act_dim * env_driver.num_agents)

        critic_input = torch.cat(
            [b_states_t.reshape(BATCH_SIZE, -1),
             actions_t.reshape(BATCH_SIZE, -1)],
            dim=1)
        assert (critic_input.shape[0] == BATCH_SIZE)
        actor_loss = -this_agent.critic_local.forward(critic_input).mean()
        actor_loss.backward()
        this_agent.actor_optimizer.step()
        self.__update_target()

    def __update_target(self):
        for agent in self.ddpg_agent_list:
            soft_update(agent.actor_target, agent.actor_local, tau=TAU)
            soft_update(agent.critic_target, agent.critic_local, tau=TAU)
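
None of the examples show how TennisMultiAgent is driven. The episode loop below is a hypothetical usage sketch: reset(), act_for_env() and update() come from the class above, while env_driver.reset() and an env_driver.step() that returns per-agent (next_states, rewards, dones) are assumptions about the environment wrapper.

import numpy as np


def run_episode(env_driver, multi_agent, max_steps=1000):
    # Hypothetical driver loop; env_driver.reset()/step() are assumed interfaces.
    states = env_driver.reset()
    multi_agent.reset()
    scores = np.zeros(env_driver.num_agents)
    for _ in range(max_steps):
        actions = multi_agent.act_for_env(states)                # shape: (num_agents, act_dim)
        next_states, rewards, dones = env_driver.step(actions)   # assumed return signature
        multi_agent.update(states, actions, rewards, next_states, dones)
        scores += rewards
        states = next_states
        if np.any(dones):
            break
    return scores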