def run():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    gym.logger.set_level(40)
    env = gym.make(args.env_name)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.shape[0]
    state_stat = RunningStat(env.observation_space.shape, eps=1e-2)
    action_space = env.action_space
    policy = Policy(state_size, action_size, args.hidden_size,
                    action_space.low, action_space.high)
    num_params = policy.num_params
    optim = Adam(num_params, args.lr)

    ray.init(num_cpus=args.num_parallel)

    return_list = []
    for epoch in range(100000):
        #####################################
        ### Rollout and Update State Stat ###
        #####################################

        policy.set_state_stat(state_stat.mean, state_stat.std)

        # set diff params (mirror sampling)
        assert args.episodes_per_batch % 2 == 0
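        # Antithetic (mirrored) sampling: each Gaussian perturbation eps is
        # paired with -eps in the next row, which reduces the variance of the
        # gradient estimate (rows alternate +eps_0, -eps_0, +eps_1, ...).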
        diff_params = torch.empty((args.episodes_per_batch, num_params),
                                  dtype=torch.float)
        diff_params_pos = torch.randn(args.episodes_per_batch // 2,
                                      num_params) * args.noise_std
        diff_params[::2] = diff_params_pos
        diff_params[1::2] = -diff_params_pos

        rets = []
        num_episodes_popped = 0
        num_timesteps_popped = 0
        while num_episodes_popped < args.episodes_per_batch \
                and num_timesteps_popped < args.timesteps_per_batch:
            results = []
            # never launch more rollouts than there are unused perturbations
            for i in range(
                    min(args.episodes_per_batch - num_episodes_popped, 500)):
                # set policy
                randomized_policy = deepcopy(policy)
                randomized_policy.add_params(diff_params[num_episodes_popped +
                                                         i])
                # rollout
                results.append(
                    rollout.remote(randomized_policy,
                                   args.env_name,
                                   seed=np.random.randint(0, 10000000)))

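            # Gather the asynchronous rollouts; ray.get blocks until each
            # remote worker has finished its episode.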
            for result in results:
                ret, timesteps, states = ray.get(result)
                rets.append(ret)
                # update state stat
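                # accumulate per-dimension sums, squared sums, and the sample
                # count so the running mean/std used for observation
                # normalization can be updated incrementally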
                if states is not None:
                    state_stat.increment(states.sum(axis=0),
                                         np.square(states).sum(axis=0),
                                         states.shape[0])

                num_timesteps_popped += timesteps
                num_episodes_popped += 1
        rets = np.array(rets, dtype=np.float32)
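        # keep only the perturbation rows that were actually evaluated so the
        # noise matrix stays aligned with the collected returns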
        diff_params = diff_params[:num_episodes_popped]

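        # Re-evaluate this batch's best perturbation with 10 independent test
        # rollouts (calc_state_stat_prob=0.0 disables state-stat updates and
        # test=True presumably puts the rollout in evaluation mode).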
        best_policy_idx = np.argmax(rets)
        best_policy = deepcopy(policy)
        best_policy.add_params(diff_params[best_policy_idx])
        best_rets = [
            rollout.remote(best_policy,
                           args.env_name,
                           seed=np.random.randint(0, 10000000),
                           calc_state_stat_prob=0.0,
                           test=True) for _ in range(10)
        ]
        best_rets = np.average(ray.get(best_rets))

        print('epoch:', epoch, 'mean:', np.average(rets), 'max:', np.max(rets),
              'best:', best_rets)
        with open(args.outdir + '/return.csv', 'w') as f:
            return_list.append(
                [epoch, np.max(rets),
                 np.average(rets), best_rets])
            writer = csv.writer(f, lineterminator='\n')
            writer.writerows(return_list)

            plt.figure()
            sns.lineplot(data=np.array(return_list)[:, 1:])
            plt.savefig(args.outdir + '/return.png')
            plt.close('all')

        #############
        ### Train ###
        #############

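        # Rank-based fitness shaping: returns are replaced by their centered
        # ranks (roughly in [-0.5, 0.5]), making the update invariant to the
        # scale of the raw returns and robust to outlier episodes.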
        fitness = compute_centered_ranks(rets).reshape(-1, 1)
        if args.weight_decay > 0:
            # L1 weight-decay penalty on the perturbed parameters; subtracting
            # it from the fitness discourages large weights (an L2 variant
            # would penalize the squared parameters instead).
            l1_decay = args.weight_decay * (policy.get_params() +
                                            diff_params).abs().mean(
                                                dim=1, keepdim=True).numpy()
            fitness -= l1_decay
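        # ES gradient estimate: fitness-weighted average of the sampled noise
        # directions. Assuming optim.update follows the usual minimization
        # convention, passing -grad performs ascent on the shaped fitness.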
        grad = (fitness * diff_params.numpy()).mean(axis=0)
        policy = optim.update(policy, -grad)
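

######################
### CMA-ES Variant ###
######################
# Same rollout / evaluation / logging loop as above, but candidate parameter
# vectors are sampled and adapted by cma.CMAEvolutionStrategy instead of the
# fixed-noise mirrored perturbations updated with Adam.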
def run():
    np.random.seed(args.seed)
    torch.manual_seed(args.seed)

    gym.logger.set_level(40)
    env = gym.make(args.env_name)
    state_size = env.observation_space.shape[0]
    action_size = env.action_space.shape[0]
    state_stat = RunningStat(env.observation_space.shape, eps=1e-2)
    action_space = env.action_space
    policy = Policy(state_size, action_size, args.hidden_size,
                    action_space.low, action_space.high)
    num_params = policy.num_params
    es = cma.CMAEvolutionStrategy([0] * num_params,
                                  args.sigma_init,
                                  {'popsize': args.popsize})
    
    ray.init(num_cpus=args.num_parallel)

    return_list = []
    for epoch in range(100000):
        #####################################
        ### Rollout and Update State Stat ###
        #####################################

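        # es.ask() samples a population of candidate parameter vectors from
        # the current CMA-ES search distribution (mean and covariance).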
        solutions = np.array(es.ask(), dtype=np.float32)
        policy.set_state_stat(state_stat.mean, state_stat.std)

        rets = []
        results = []
        for i in range(args.popsize):
            # set policy
            randomized_policy = deepcopy(policy)
            randomized_policy.set_params(solutions[i])
            # rollout
            results.append(
                rollout.remote(randomized_policy,
                               args.env_name,
                               seed=np.random.randint(0, 10000000)))
        
        for result in results:
            ret, timesteps, states = ray.get(result)
            rets.append(ret)
            # update state stat
            if states is not None:
                state_stat.increment(states.sum(axis=0),
                                     np.square(states).sum(axis=0),
                                     states.shape[0])
            
        rets = np.array(rets, dtype=np.float32)
        
        best_policy_idx = np.argmax(rets)
        best_policy = deepcopy(policy)
        best_policy.set_params(solutions[best_policy_idx])
        best_rets = [
            rollout.remote(best_policy,
                           args.env_name,
                           seed=np.random.randint(0, 10000000),
                           calc_state_stat_prob=0.0,
                           test=True) for _ in range(10)
        ]
        best_rets = np.average(ray.get(best_rets))
        
        print('epoch:', epoch, 'mean:', np.average(rets), 'max:', np.max(rets),
              'best:', best_rets)
        with open(args.outdir + '/return.csv', 'w') as f:
            return_list.append(
                [epoch, np.max(rets),
                 np.average(rets), best_rets])
            writer = csv.writer(f, lineterminator='\n')
            writer.writerows(return_list)

            plt.figure()
            sns.lineplot(data=np.array(return_list)[:, 1:])
            plt.savefig(args.outdir + '/return.png')
            plt.close('all')
        

        #############
        ### Train ###
        #############

        fitness = compute_centered_ranks(rets)
        if args.weight_decay > 0:
            l2_decay = compute_weight_decay(args.weight_decay, solutions)
            fitness -= l2_decay
        # cma minimizes its objective, so negate the fitness to maximize it
        es.tell(solutions, (-fitness).tolist())