import csv
from copy import deepcopy

import gym
import matplotlib.pyplot as plt
import numpy as np
import ray
import seaborn as sns
import torch

# Policy, RunningStat, Adam, rollout, compute_centered_ranks and the parsed
# command-line args are assumed to be defined elsewhere in the project.


def run():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    gym.logger.set_level(40)

    env = gym.make(args.env_name)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.shape[0]
    state_stat = RunningStat(env.observation_space.shape, eps=1e-2)
    action_space = env.action_space

    policy = Policy(state_size, action_size, args.hidden_size,
                    action_space.low, action_space.high)
    num_params = policy.num_params
    optim = Adam(num_params, args.lr)

    ray.init(num_cpus=args.num_parallel)

    return_list = []
    for epoch in range(100000):
        #####################################
        ### Rollout and Update State Stat ###
        #####################################
        policy.set_state_stat(state_stat.mean, state_stat.std)

        # set diff params (mirror sampling)
        assert args.episodes_per_batch % 2 == 0
        diff_params = torch.empty((args.episodes_per_batch, num_params),
                                  dtype=torch.float)
        diff_params_pos = torch.randn(args.episodes_per_batch // 2,
                                      num_params) * args.noise_std
        diff_params[::2] = diff_params_pos
        diff_params[1::2] = -diff_params_pos

        rets = []
        num_episodes_popped = 0
        num_timesteps_popped = 0
        while num_episodes_popped < args.episodes_per_batch \
                and num_timesteps_popped < args.timesteps_per_batch:
            # or num_timesteps_popped < args.timesteps_per_batch:
            results = []
            for i in range(min(args.episodes_per_batch, 500)):
                # set policy
                randomized_policy = deepcopy(policy)
                randomized_policy.add_params(
                    diff_params[num_episodes_popped + i])
                # rollout
                results.append(
                    rollout.remote(randomized_policy, args.env_name,
                                   seed=np.random.randint(0, 10000000)))
            for result in results:
                ret, timesteps, states = ray.get(result)
                rets.append(ret)
                # update state stat
                if states is not None:
                    state_stat.increment(states.sum(axis=0),
                                         np.square(states).sum(axis=0),
                                         states.shape[0])
                num_timesteps_popped += timesteps
                num_episodes_popped += 1

        rets = np.array(rets, dtype=np.float32)
        diff_params = diff_params[:num_episodes_popped]

        # evaluate the best perturbation of this batch over 10 test episodes
        best_policy_idx = np.argmax(rets)
        best_policy = deepcopy(policy)
        best_policy.add_params(diff_params[best_policy_idx])
        best_rets = [
            rollout.remote(best_policy, args.env_name,
                           seed=np.random.randint(0, 10000000),
                           calc_state_stat_prob=0.0, test=True)
            for _ in range(10)
        ]
        best_rets = np.average(ray.get(best_rets))
        print('epoch:', epoch, 'mean:', np.average(rets),
              'max:', np.max(rets), 'best:', best_rets)

        # log returns to CSV and plot the learning curve
        with open(args.outdir + '/return.csv', 'w') as f:
            return_list.append(
                [epoch, np.max(rets), np.average(rets), best_rets])
            writer = csv.writer(f, lineterminator='\n')
            writer.writerows(return_list)
        plt.figure()
        sns.lineplot(data=np.array(return_list)[:, 1:])
        plt.savefig(args.outdir + '/return.png')
        plt.close('all')

        #############
        ### Train ###
        #############
        # rank-based fitness shaping, then an ES gradient estimate from the
        # fitness-weighted perturbations; Adam performs the ascent step
        fitness = compute_centered_ranks(rets).reshape(-1, 1)
        if args.weight_decay > 0:
            # l2_decay = args.weight_decay * ((policy.get_params() + diff_params)**2).mean(dim=1, keepdim=True).numpy()
            l1_decay = args.weight_decay * (policy.get_params() + diff_params).mean(
                dim=1, keepdim=True).numpy()
            fitness += l1_decay
        grad = (fitness * diff_params.numpy()).mean(axis=0)
        policy = optim.update(policy, -grad)
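# --- A minimal, self-contained sketch (not part of the training script) of the
# update that the "Train" block above performs, on a toy quadratic instead of
# environment rollouts. The helper centered_ranks() is an illustrative stand-in
# for the project's compute_centered_ranks; all names and hyperparameters here
# are assumptions for the example only.
import numpy as np


def centered_ranks(x):
    # map values to evenly spaced ranks in [-0.5, 0.5]; higher value, higher rank
    ranks = np.empty(len(x), dtype=np.float32)
    ranks[x.argsort()] = np.arange(len(x), dtype=np.float32)
    return ranks / (len(x) - 1) - 0.5


def toy_es_step(theta, rng, noise_std=0.1, pop=32, step_size=1.0):
    # mirror sampling: each noise vector is used with both signs
    eps_pos = rng.standard_normal((pop // 2, theta.size)) * noise_std
    eps = np.concatenate([eps_pos, -eps_pos])
    # toy "returns" of the perturbed parameters: maximal at theta = 3 everywhere
    rets = -np.sum((theta + eps - 3.0) ** 2, axis=1)
    # fitness-weighted average of the perturbations is the ES gradient estimate
    fitness = centered_ranks(rets).reshape(-1, 1)
    grad = (fitness * eps).mean(axis=0)
    return theta + step_size * grad  # plain ascent step instead of Adam


rng = np.random.default_rng(0)
theta = np.zeros(3)
for _ in range(1000):
    theta = toy_es_step(theta, rng)
print(theta)  # settles near 3.0 in every coordinate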
import cma  # in addition to the imports above (csv, deepcopy, gym, matplotlib, numpy, ray, seaborn, torch)


def run():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)
    gym.logger.set_level(40)

    env = gym.make(args.env_name)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.shape[0]
    state_stat = RunningStat(env.observation_space.shape, eps=1e-2)
    action_space = env.action_space

    policy = Policy(state_size, action_size, args.hidden_size,
                    action_space.low, action_space.high)
    num_params = policy.num_params
    es = cma.CMAEvolutionStrategy([0] * num_params, args.sigma_init,
                                  {'popsize': args.popsize})

    ray.init(num_cpus=args.num_parallel)

    return_list = []
    for epoch in range(100000):
        #####################################
        ### Rollout and Update State Stat ###
        #####################################
        solutions = np.array(es.ask(), dtype=np.float32)
        policy.set_state_stat(state_stat.mean, state_stat.std)

        rets = []
        results = []
        for i in range(args.popsize):
            # set policy
            randomized_policy = deepcopy(policy)
            randomized_policy.set_params(solutions[i])
            # rollout
            results.append(
                rollout.remote(randomized_policy, args.env_name,
                               seed=np.random.randint(0, 10000000)))
        for result in results:
            ret, timesteps, states = ray.get(result)
            rets.append(ret)
            # update state stat
            if states is not None:
                state_stat.increment(states.sum(axis=0),
                                     np.square(states).sum(axis=0),
                                     states.shape[0])
        rets = np.array(rets, dtype=np.float32)

        # evaluate the best candidate of this generation over 10 test episodes
        best_policy_idx = np.argmax(rets)
        best_policy = deepcopy(policy)
        best_policy.set_params(solutions[best_policy_idx])
        best_rets = [
            rollout.remote(best_policy, args.env_name,
                           seed=np.random.randint(0, 10000000),
                           calc_state_stat_prob=0.0, test=True)
            for _ in range(10)
        ]
        best_rets = np.average(ray.get(best_rets))
        print('epoch:', epoch, 'mean:', np.average(rets),
              'max:', np.max(rets), 'best:', best_rets)

        # log returns to CSV and plot the learning curve
        with open(args.outdir + '/return.csv', 'w') as f:
            return_list.append(
                [epoch, np.max(rets), np.average(rets), best_rets])
            writer = csv.writer(f, lineterminator='\n')
            writer.writerows(return_list)
        plt.figure()
        sns.lineplot(data=np.array(return_list)[:, 1:])
        plt.savefig(args.outdir + '/return.png')
        plt.close('all')

        #############
        ### Train ###
        #############
        # rank-based fitness shaping with optional weight decay
        ranks = compute_centered_ranks(rets)
        fitness = ranks
        if args.weight_decay > 0:
            l2_decay = compute_weight_decay(args.weight_decay, solutions)
            fitness -= l2_decay
        # convert minimize to maximize: cma minimizes, so negate the fitness
        es.tell(solutions, -fitness)
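# --- A minimal sketch (separate from the training script) of the cma ask/tell
# loop and its sign convention, on a toy objective instead of rollouts:
# CMAEvolutionStrategy is a minimizer, so a reward-like fitness (higher is
# better) has to be negated before es.tell(), as in the training block above.
# The objective and hyperparameters here are arbitrary choices for the example.
import cma
import numpy as np

es = cma.CMAEvolutionStrategy([0.0] * 5, 0.5, {'popsize': 16, 'verbose': -9})
for generation in range(100):
    solutions = es.ask()  # candidate parameter vectors
    # toy "reward": higher is better, maximal at x = 3 in every coordinate
    rewards = np.array([-np.sum((np.asarray(x) - 3.0) ** 2) for x in solutions])
    es.tell(solutions, (-rewards).tolist())  # maximize reward == minimize -reward
print(es.result.xbest)  # close to 3.0 in every coordinate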