def main(num_process=6):
    """Launch asynchronous learner processes and run the test loop in this process.

    Builds a shared actor/critic pair plus a ``nets.SharedAdam`` optimizer for
    each, places all four in shared memory, and starts ``num_process`` worker
    processes running ``a3c.learning_thread``.  The calling process then runs
    ``a2c.test_procedure`` with the same shared actor.

    Args:
        num_process: number of learner processes to spawn (default 6, the
            previously hard-coded value).
    """
    # Imported locally: the module-level import of ``multiprocessing`` only
    # happens under ``if __name__ == '__main__'``, so without this main()
    # raises NameError when called from another module.
    from torch import multiprocessing

    # This env is used both to read the action count and, below, as the
    # environment for the in-process test loop (it is not closed).
    env = a2c.get_env()
    print("num actions ", env.action_space.n)
    shared_actor = nets.Actor(num_actions=env.action_space.n)
    shared_critic = nets.Critic()
    shared_actor_optim = nets.SharedAdam(shared_actor.parameters())
    shared_critic_optim = nets.SharedAdam(shared_critic.parameters())

    # Move models and optimizer state to shared memory so child processes
    # operate on the same storage.
    shared_actor.share_memory()
    shared_critic.share_memory()
    shared_actor_optim.share_memory()
    shared_critic_optim.share_memory()

    worker_args = (shared_actor, shared_critic,
                   shared_actor_optim, shared_critic_optim)
    processes = [
        multiprocessing.Process(target=a3c.learning_thread, args=worker_args)
        for _ in range(num_process)
    ]
    for p in processes:
        p.start()

    # NOTE(review): if test_procedure never returns (the version defined in
    # this file loops forever), the joins below are never reached.
    a2c.test_procedure(shared_actor, env)
    for p in processes:
        p.join()
def test_procedure(shared_actor, env):
    """Repeatedly play evaluation episodes with the latest shared actor weights.

    At the start of every episode the shared actor's parameters are copied
    into a local actor and a fresh frame buffer is created.  Actions are
    chosen by ``utils.epsilon_greedy`` with ``epsilon=-1.`` (presumably fully
    greedy -- confirm against utils).  When an episode ends, the elapsed time
    since the first episode and the counts of -1 ("computer") and +1
    ("agent") rewards are printed.  This function never returns.

    Args:
        shared_actor: actor model whose state_dict is read each episode.
        env: environment exposing ``action_space.n``, ``reset()`` and
            ``step(action)`` returning a 4-tuple.
    """
    num_actions = env.action_space.n
    local_actor = nets.Actor(num_actions=num_actions)
    begin_time = time.time()
    while True:
        # New buffer per episode so stacked frames never span two episodes.
        replay_buffer = utils.ReplayBuffer(size=4, frame_history_len=4)
        # Load parameters from the shared model.
        local_actor.load_state_dict(shared_actor.state_dict())
        obs = env.reset()
        rewards = []
        done = False
        while not done:
            replay_buffer.store_frame(obs)
            states = replay_buffer.encode_recent_observation()
            states = np.expand_dims(states, axis=0) / 255.0 - .5
            # Inference only: volatile=True avoids building an autograd graph
            # (same convention as learning_thread's action selection).
            logits = local_actor(
                Variable(torch.from_numpy(states.astype(np.float32)),
                         volatile=True))
            action = utils.epsilon_greedy(logits, num_actions=num_actions,
                                          epsilon=-1.)
            obs, reward, done, _ = env.step(action)
            rewards.append(reward)
        episode_rewards = np.array(rewards)
        print("Time:{}, computer:{}, agent:{}".format(
            time.time() - begin_time,
            sum(episode_rewards == -1),
            sum(episode_rewards == 1)))
def learning_thread(shared_actor,
                    shared_critic,
                    shared_actor_optim,
                    shared_critic_optim,
                    exploration=LinearSchedule(1000000, 0.1),
                    gamma=0.99,
                    frame_history_len=4):
    """Worker loop for one asynchronous actor-critic learner process.

    Intended algorithm (from the original outline):
      1. build a local model
      2. synchronize the shared model parameters and local model
      3. choose an action based on observation
      4. take an action and get the reward and the next observation
      5. calculate the target, and accumulate the gradient
      6. update the global model

    Steps 1-4 and the n-step target computation of step 5 are implemented.
    Gradient accumulation and the shared-model update (rest of steps 5-6) are
    NOT implemented yet.  ``criterion``, ``shared_actor_optim`` and
    ``shared_critic_optim`` are accepted but currently unused.

    Args:
        shared_actor, shared_critic: models whose parameters are copied into
            the local models before every rollout.
        shared_actor_optim, shared_critic_optim: shared optimizers (unused so far).
        exploration: called with the worker's global step count to obtain
            the epsilon passed to ``utils.epsilon_greedy``.
        gamma: discount factor applied to the bootstrapped next-state value.
        frame_history_len: number of frames stacked per observation.

    This function never returns.
    """
    # Prepare the environment and local models (step 1).
    env = get_env()
    num_actions = env.action_space.n
    local_actor = nets.Actor(num_actions=num_actions)
    local_critic = nets.Critic()
    criterion = nn.MSELoss(size_average=False)  # reserved for the critic loss; not used yet

    num_n_steps = 4
    # Counts every env step taken by this worker.  It drives the exploration
    # schedule; the old code passed the inner loop index (always 0..3), so
    # epsilon never decayed.
    total_steps = 0

    obs = env.reset()
    # One buffer per episode (as in test_procedure) so frame stacks never mix
    # frames from different episodes.
    replay_buffer = utils.ReplayBuffer(size=4, frame_history_len=frame_history_len)
    replay_buffer.store_frame(obs)

    for _ in itertools.count():
        # Step 2: pull the latest shared parameters before every rollout so
        # the worker sees updates made by other processes.
        local_actor.load_state_dict(shared_actor.state_dict())
        local_critic.load_state_dict(shared_critic.state_dict())

        states = []
        actions = []
        next_states = []
        dones = []
        rewards = []
        for _ in range(num_n_steps):
            # The most recent stored frame is the current observation.  Keep
            # the raw stack; normalization is applied once when batching.
            raw_state = replay_buffer.encode_recent_observation()
            state = np.expand_dims(raw_state, axis=0) / 255.0 - .5
            state = Variable(torch.from_numpy(state.astype(np.float32)),
                             volatile=True)
            logits = local_actor(state)
            action = utils.epsilon_greedy(logits, num_actions=num_actions,
                                          epsilon=exploration(total_steps))
            total_steps += 1

            next_obs, reward, done, _ = env.step(action)
            # Store each new frame exactly once; it becomes the "current"
            # frame for the next step.
            replay_buffer.store_frame(next_obs)

            states.append(raw_state)
            actions.append(action)
            dones.append(done)
            rewards.append(reward)
            next_states.append(replay_buffer.encode_recent_observation())

            if done:
                # Start a new episode; the old code kept stepping a finished env.
                obs = env.reset()
                replay_buffer = utils.ReplayBuffer(
                    size=4, frame_history_len=frame_history_len)
                replay_buffer.store_frame(obs)
                break

        # Compute targets: r + gamma * V(s') * (1 - done).
        # From numpy to torch.Variable, normalizing each batch exactly once.
        cur_states = np.array(states) / 255.0 - .5
        cur_states = Variable(torch.from_numpy(cur_states.astype(np.float32)))
        next_states = np.array(next_states) / 255.0 - .5
        next_states = Variable(
            torch.from_numpy(next_states.astype(np.float32)), volatile=True)
        not_done_mask = torch.from_numpy(
            1 - np.array(dones).astype(np.float32)).view(-1, 1)
        rewards = torch.from_numpy(
            np.array(rewards).astype(np.float32)).view(-1, 1)
        # NOTE(review): assumes local_critic returns shape (N, 1) -- confirm in nets.
        values = local_critic(next_states)
        targets = values.data.mul_(not_done_mask).mul_(gamma)
        targets = targets.add_(rewards)
        # TODO: accumulate actor/critic gradients on cur_states/actions against
        # `targets` and apply them via the shared optimizers (steps 5-6).
if __name__ == '__main__':
    # Script entry point: run only the evaluation loop in a child process,
    # against freshly initialized shared models (no learners, no optimizers).
    from torch import multiprocessing

    # Created only to read the size of the discrete action space.
    env = a3c.get_env()
    num_actions = env.action_space.n
    print("num actions ", num_actions)

    shared_actor = nets.Actor(num_actions=num_actions)
    shared_critic = nets.Critic()
    # Move parameters into shared memory so they can be shared with the child.
    for shared_model in (shared_actor, shared_critic):
        shared_model.share_memory()

    # Shared optimizers are deliberately not built here (main() builds them
    # with nets.SharedAdam when training).
    tester = multiprocessing.Process(target=test_procedure,
                                     args=(shared_actor, a3c.get_env()))
    tester.start()
    tester.join()