def post_evaluate(models_path, sigma, n_post_episodes=5, add_noise=False):
    # print('----------------Post evaluation----------------')
    policy_path = models_path + "_policy"
    value_path = models_path + "_value"
    if args.use_parameter_noise:
        policy_post = PolicyLayerNorm(num_inputs, num_actions)
        value_post = Value(num_inputs)
    else:
        policy_post = Policy(num_inputs, num_actions)
        value_post = Value(num_inputs)
    # print('------------------')
    value_post.load_state_dict(torch.load(value_path))
    policy_post.load_state_dict(torch.load(policy_path))
    reward_post = 0
    for i in range(n_post_episodes):
        state = env.reset()
        ## seeding
        # env.seed(i)
        # torch.manual_seed(i)
        # state = running_state(state)
        for t in range(1000):
            if args.use_parameter_noise and add_noise:
                action = select_action(policy_post, state, sigma, add_noise=True)
            else:
                action = select_action(policy_post, state)
            action = action.data[0].numpy()
            next_state, reward, done, _ = env.step(action)
            reward_post += reward
            # next_state = running_state(next_state)
            if done:
                break
            # state = running_state(next_state)
            state = next_state
    print('___Post evaluation reward___')
    print(reward_post / n_post_episodes)
    return reward_post / n_post_episodes
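# The helpers referenced above (args, env, num_inputs, num_actions, select_action,
# Policy, PolicyLayerNorm, Value, running_state) are defined elsewhere in this
# repo.  The function below is only a minimal sketch of what a Gaussian
# select_action with an optional parameter-noise branch could look like, written
# from the call sites above; it is an assumption for illustration, not the
# repository's actual implementation.
import copy  # assumed available; torch and torch.autograd.Variable as well


def select_action_sketch(policy_net, state, sigma=None, add_noise=False):
    state = Variable(torch.Tensor(state).unsqueeze(0))
    if add_noise and sigma is not None:
        # Parameter-space noise (assumed): act with a perturbed copy of the
        # policy, adding Gaussian noise of scale sigma to every weight.
        policy_net = copy.deepcopy(policy_net)
        for p in policy_net.parameters():
            p.data += torch.randn(p.size()) * sigma
    # Assumes the policy returns the mean and variance of a Gaussian,
    # as the train() function below does.
    mu, sigma_sq = policy_net(state)
    action = mu + sigma_sq.sqrt() * Variable(torch.randn(mu.size()))
    return action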
def train(rank, params, shared_p, shared_v, optimizer_p, optimizer_v):
    torch.manual_seed(params.seed + rank)
    env = gym.make(params.env_name)
    num_inputs = env.observation_space.shape[0]
    num_outputs = env.action_space.shape[0]
    policy = Policy(num_inputs, num_outputs)
    value = Value(num_inputs)
    memory = ReplayMemory(1e6)
    batch_size = 10000
    state = env.reset()
    state = Variable(torch.Tensor(state).unsqueeze(0))
    done = True
    episode_length = 0
    while True:
        episode_length += 1
        # sync the local networks with the shared ones before collecting data
        policy.load_state_dict(shared_p.state_dict())
        value.load_state_dict(shared_v.state_dict())
        w = -1
        while w < batch_size:
            states = []
            actions = []
            rewards = []
            values = []
            returns = []
            advantages = []
            # Perform K steps
            for step in range(params.num_steps):
                w += 1
                states.append(state)
                mu, sigma_sq = policy(state)
                # sample an action from the Gaussian policy
                eps = torch.randn(mu.size())
                action = (mu + sigma_sq.sqrt() * Variable(eps))
                actions.append(action)
                v = value(state)
                values.append(v)
                env_action = action.data.squeeze().numpy()
                state, reward, done, _ = env.step(env_action)
                done = (done or episode_length >= params.max_episode_length)
                reward = max(min(reward, 1), -1)
                rewards.append(reward)
                if done:
                    episode_length = 0
                    state = env.reset()
                state = Variable(torch.Tensor(state).unsqueeze(0))
                if done:
                    break
            # bootstrap the return with the value of the last state if the
            # episode did not terminate
            R = torch.zeros(1, 1)
            if not done:
                v = value(state)
                R = v.data
            # compute returns and advantages:
            values.append(Variable(R))
            R = Variable(R)
            for i in reversed(range(len(rewards))):
                R = params.gamma * R + rewards[i]
                returns.insert(0, R)
                A = R - values[i]
                advantages.insert(0, A)
            # store useful info:
            memory.push([states, actions, returns, advantages])
        batch_states, batch_actions, batch_returns, batch_advantages = memory.sample(batch_size)

        # policy grad updates:
        mu_old, sigma_sq_old = policy(batch_states)
        probs_old = normal(batch_actions, mu_old, sigma_sq_old)
        policy_new = Policy(num_inputs, num_outputs)
        kl = 0.
        kl_coef = 1.
        kl_target = Variable(torch.Tensor([params.kl_target]))
        for m in range(100):
            policy_new.load_state_dict(shared_p.state_dict())
            mu_new, sigma_sq_new = policy_new(batch_states)
            probs_new = normal(batch_actions, mu_new, sigma_sq_new)
            # surrogate objective weighted by the advantages, plus an adaptive
            # KL penalty
            policy_loss = torch.mean(batch_advantages * torch.sum(probs_new / probs_old, 1))
            kl = torch.mean(probs_old * torch.log(probs_old / probs_new))
            kl_loss = kl_coef * kl + \
                params.ksi * torch.clamp(kl - 2 * kl_target, max=0) ** 2
            total_policy_loss = -policy_loss + kl_loss
            # stop early if the new policy drifts too far from the old one
            if kl > 4 * kl_target:
                break
            # asynchronous update:
            optimizer_p.zero_grad()
            total_policy_loss.backward()
            ensure_shared_grads(policy_new, shared_p)
            optimizer_p.step()

        # value grad updates:
        for b in range(100):
            value.load_state_dict(shared_v.state_dict())
            v = value(batch_states)
            value_loss = torch.mean((batch_returns - v) ** 2)
            # asynchronous update:
            optimizer_v.zero_grad()
            value_loss.backward()
            ensure_shared_grads(value, shared_v)
            optimizer_v.step()

        # adapt the KL penalty coefficient
        if kl > params.beta_hight * kl_target:
            kl_coef *= params.alpha
        if kl < params.beta_low * kl_target:
            kl_coef /= params.alpha
        print("update done !")
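# train() is written to run in several worker processes (hence the rank
# argument) that share shared_p / shared_v, e.g. via torch.multiprocessing,
# which is why local gradients are pushed into the shared models.  The helper
# functions it uses (normal, ensure_shared_grads, Policy, Value, ReplayMemory)
# live in other files of this repo.  The two sketches below are common
# implementations and only assumptions about what this repo does, not its
# actual code.
import math  # assumed available; torch and torch.autograd.Variable as well


def normal_sketch(x, mu, sigma_sq):
    # element-wise Gaussian density N(x; mu, sigma_sq)
    a = (-(x - mu).pow(2) / (2 * sigma_sq)).exp()
    b = 1 / (2 * sigma_sq * math.pi).sqrt()
    return a * b


def ensure_shared_grads_sketch(model, shared_model):
    # point the shared parameters' gradients at the local ones (the usual
    # trick from asynchronous actor-critic implementations)
    for param, shared_param in zip(model.parameters(), shared_model.parameters()):
        if shared_param.grad is not None:
            return
        shared_param._grad = param.grad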