Beispiel #1
0
    def learn(self, total_steps):
        """
        The actual training loop.

        Repeatedly collects rollouts until ``self.epoch_batch_size`` environment
        steps have been gathered, computes discounted returns and advantages,
        runs ``self.sgd_epochs`` policy and value updates, and then advances the
        learning-rate schedule. Stops after ``total_steps`` environment steps, or
        earlier once the two most recent episode rewards both reach
        ``self.reward_stop``.

        Args:
            total_steps: number of environment steps to train for

        Returns:
            model: trained model (``self.model``)
            raw_rew_hist: ``self.raw_rew_hist``, the total (unnormalized) reward of
                every episode collected so far
            var_dict: dictionary with all locals, for logging/debugging purposes
        """

        # init everything
        # ==============================================================================
        env = gym.make(self.env_name, **self.env_config)

        # seed all our RNGs
        cur_total_steps = 0
        env.seed(self.seed)
        torch.manual_seed(self.seed)
        np.random.seed(self.seed)

        progress_bar = tqdm.tqdm(total=total_steps)
        lr_lookup = make_schedule(self.lr_schedule, total_steps)

        self.sgd_lr = lr_lookup(0)

        progress_bar.update(0)
        early_stop = False
        self.pol_opt = torch.optim.RMSprop(self.model.policy.parameters(),
                                           lr=lr_lookup(cur_total_steps))
        self.val_opt = torch.optim.RMSprop(self.model.value_fn.parameters(),
                                           lr=lr_lookup(cur_total_steps))

        # Train until we hit our total steps or reach our reward threshold
        # ==============================================================================
        while cur_total_steps < total_steps:
            batch_obs = torch.empty(0)
            batch_act = torch.empty(0)
            batch_adv = torch.empty(0)
            batch_discrew = torch.empty(0)
            cur_batch_steps = 0

            # Bail out once the last two episodes both met the reward threshold.
            # `is not None` so that a threshold of 0 is honoured as well.
            if self.reward_stop is not None and len(self.raw_rew_hist) >= 2:
                if (self.raw_rew_hist[-1] >= self.reward_stop
                        and self.raw_rew_hist[-2] >= self.reward_stop):
                    early_stop = True
                    break

            # construct batch data from rollouts
            # ==============================================================================
            while cur_batch_steps < self.epoch_batch_size:
                ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(
                    env, self.model, self.env_no_term_steps)

                cur_batch_steps += ep_steps
                cur_total_steps += ep_steps

                # Log the episode reward before any normalization is applied.
                self.raw_rew_hist.append(sum(ep_rew).item())
                batch_obs = torch.cat((batch_obs, ep_obs.clone()))
                batch_act = torch.cat((batch_act, ep_act.clone()))

                if self.normalize_return:
                    self.rew_std = update_std(ep_rew, self.rew_std,
                                              cur_total_steps)
                    ep_rew = ep_rew / (self.rew_std + 1e-6)

                # Append one extra "reward": 0 if the rollout terminated,
                # otherwise the value estimate of the last observation.
                if ep_term:
                    ep_rew = torch.cat((ep_rew, torch.zeros(1, 1)))
                else:
                    ep_rew = torch.cat((ep_rew, self.model.value_fn(
                        ep_obs[-1]).detach().reshape(1, 1).clone()))

                # [:-1] drops the appended extra entry itself.
                ep_discrew = discount_cumsum(ep_rew, self.gamma)[:-1]
                batch_discrew = torch.cat((batch_discrew, ep_discrew.clone()))

                # TD residuals; the appended entry of ep_rew is used as the
                # value following the last observation.
                with torch.no_grad():
                    ep_val = torch.cat((self.model.value_fn(ep_obs),
                                        ep_rew[-1].reshape(1, 1).clone()))
                    deltas = ep_rew[:-1] + self.gamma * ep_val[1:] - ep_val[:-1]

                ep_adv = discount_cumsum(deltas, self.gamma * self.lam)
                batch_adv = torch.cat((batch_adv, ep_adv.clone()))

            # PostProcess epoch and update weights
            # ==============================================================================
            if self.normalize_adv:
                # make sure our advantages are zero mean and unit variance
                batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() +
                                                              1e-6)

            # Update the policy using the PPO loss
            for pol_epoch in range(self.sgd_epochs):
                pol_loss, approx_kl = self.policy_update(
                    batch_act, batch_obs, batch_adv)
                if approx_kl > self.target_kl:
                    print("KL Stop")
                    break

            for val_epoch in range(self.sgd_epochs):
                val_loss = self.value_update(batch_obs, batch_discrew)

            # update observation mean and variance
            if self.normalize_obs:
                self.obs_mean = update_mean(batch_obs, self.obs_mean,
                                            cur_total_steps)
                self.obs_std = update_std(batch_obs, self.obs_std,
                                          cur_total_steps)
                self.model.policy.state_means = self.obs_mean
                self.model.value_fn.state_means = self.obs_mean
                self.model.policy.state_std = self.obs_std
                self.model.value_fn.state_std = self.obs_std

            # Advance the learning-rate schedule and actually apply it to both
            # optimizers (the logged lrp/lrv values are read back from them).
            sgd_lr = lr_lookup(cur_total_steps)
            self.sgd_lr = sgd_lr
            for opt in (self.pol_opt, self.val_opt):
                for group in opt.param_groups:
                    group['lr'] = sgd_lr

            self.old_model = copy.deepcopy(self.model)
            self.val_loss_hist.append(val_loss.detach())
            self.pol_loss_hist.append(pol_loss.detach())
            self.lrp_hist.append(
                self.pol_opt.state_dict()['param_groups'][0]['lr'])
            self.lrv_hist.append(
                self.val_opt.state_dict()['param_groups'][0]['lr'])
            self.kl_hist.append(approx_kl.detach())
            self.entropy_hist.append(self.model.policy.logstds.detach())

            progress_bar.update(cur_batch_steps)

        progress_bar.close()
        return self.model, self.raw_rew_hist, locals()
Beispiel #2
0
def ars(env_name, n_epochs, env_config, step_size, n_delta, n_top, exp_noise,
        n_workers, policy, seed):
    """
    Augmented Random Search (https://arxiv.org/abs/1803.07055).

    Each epoch samples ``n_delta`` Gaussian directions and evaluates the policy
    at ``W + exp_noise * delta`` and ``W - exp_noise * delta`` in a process
    pool. It keeps the ``n_top`` directions with the best ``max(r+, r-)`` and
    steps the flat parameter vector along their return-weighted sum. It then
    updates ``policy.state_means`` / ``policy.state_std`` from all visited states.

    Args:
        env_name: name of the gym environment to solve
        n_epochs: number of parameter updates to perform
        env_config: kwargs passed to ``gym.make``; ``None`` means no kwargs
        step_size: step size of the parameter update
        n_delta: number of search directions sampled per epoch
        n_top: number of best-performing directions used in each update
        exp_noise: scale of the perturbation applied along each direction
        n_workers: number of processes in the rollout pool
        policy: torch module whose parameters are searched over; its
            ``state_means`` / ``state_std`` attributes are overwritten
        seed: seed for the env, torch and numpy RNGs

    Returns:
        policy: ``policy`` with the final parameter vector written into it
        r_hist: per-epoch mean of the third rollout output (``plr`` / ``mlr``)
            over both evaluations of the top directions
    """
    torch.autograd.set_grad_enabled(False)  # Gradient free baby!
    pool = Pool(processes=n_workers)

    try:
        W = torch.nn.utils.parameters_to_vector(policy.parameters())
        n_param = W.shape[0]

        if env_config is None:
            env_config = {}

        env = gym.make(env_name, **env_config)
        env.seed(seed)
        torch.manual_seed(seed)
        np.random.seed(seed)
        env.close()  # this env is only created to be seeded

        total_steps = 0
        r_hist = []

        exp_dist = torch.distributions.Normal(torch.zeros(n_delta, n_param),
                                              torch.ones(n_delta, n_param))
        do_rollout_partial = partial(do_rollout_train, env_name, policy)

        for _ in range(n_epochs):
            deltas = exp_dist.sample()

            # Rows [0, n_delta) are W + noise, rows [n_delta, 2*n_delta) are W - noise.
            pm_W = torch.cat((W + (deltas * exp_noise), W - (deltas * exp_noise)))
            results = pool.map(do_rollout_partial, pm_W)

            states = torch.empty(0)
            p_returns = []
            m_returns = []
            l_returns = []
            top_returns = []

            for p_result, m_result in zip(results[:n_delta], results[n_delta:]):
                ps, pr, plr = p_result
                ms, mr, mlr = m_result

                states = torch.cat((states, ms, ps), dim=0)
                p_returns.append(pr)
                m_returns.append(mr)
                # Keep one (plus, minus) pair per direction so that indexing
                # with the direction indices in top_idx selects matching rollouts.
                l_returns.append(torch.stack((plr, mlr)))
                top_returns.append(max(pr, mr))

            top_idx = sorted(range(len(top_returns)),
                             key=lambda k: top_returns[k],
                             reverse=True)[:n_top]
            p_returns = torch.stack(p_returns)[top_idx]
            m_returns = torch.stack(m_returns)[top_idx]
            l_returns = torch.stack(l_returns)[top_idx]

            r_hist.append(l_returns.mean())

            # NOTE(review): the ARS paper scales by 1 / (n_top * sigma_R); this code
            # uses n_delta instead of n_top — confirm this is intended.
            W = W + (step_size / (n_delta * torch.cat(
                (p_returns, m_returns)).std() + 1e-6)) * torch.sum(
                    (p_returns - m_returns) * deltas[top_idx].T, dim=1)

            ep_steps = states.shape[0]
            policy.state_means = update_mean(states, policy.state_means,
                                             total_steps)
            policy.state_std = update_std(states, policy.state_std, total_steps)
            # Rebuild the partial so workers receive the updated statistics.
            do_rollout_partial = partial(do_rollout_train, env_name, policy)

            total_steps += ep_steps

            torch.nn.utils.vector_to_parameters(W, policy.parameters())
    finally:
        # Always shut down the worker processes, even if a rollout raised.
        pool.terminate()

    return policy, r_hist
Beispiel #3
0
def ppo_dim(
        env_name,
        total_steps,
        model,
        transient_length=50,
        act_var_schedule=None,
        epoch_batch_size=2048,
        gamma=0.99,
        lam=0.99,
        eps=0.2,
        seed=0,
        pol_batch_size=1024,
        val_batch_size=1024,
        pol_lr=1e-4,
        val_lr=1e-4,
        pol_epochs=10,
        val_epochs=10,
        target_kl=.01,
        use_gpu=False,
        reward_stop=None,
        normalize_return=True,
        env_config=None
):
    """
    Implements proximal policy optimization with clipping, with episode rewards
    scaled by ``var_dim`` of the visited observations.

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run the PPO for
        model: model from seagul.rl.models. Contains policy and value fn
        transient_length: number of leading observations of each episode that
            are skipped when computing the ``var_dim`` reward scale
        act_var_schedule: schedule to set the variance of the policy. Will
            linearly interpolate values. ``None`` means ``[0.7]``
        epoch_batch_size: number of environment steps to take per batch, total
            steps will be num_epochs*epoch_batch_size
        gamma: discount applied to future rewards, usually close to 1
        lam: lambda for the Advantage estimation, usually close to 1
        eps: epsilon for the clipping, usually .1 or .2
        seed: seed for all the rngs
        pol_batch_size: minibatch size for policy updates
        val_batch_size: minibatch size for value function updates
        pol_lr: learning rate for the policy optimizer
        val_lr: learning rate for the value function optimizer
        pol_epochs: how many epochs to use for each policy update
        val_epochs: how many epochs to use for each value update
        target_kl: once the approximate KL(old || new) of a minibatch exceeds
            this, the remaining policy updates of the epoch are skipped
        use_gpu: want to use the GPU? set to true (NOTE(review): the chosen
            device is currently not used anywhere)
        reward_stop: stop once the two most recent episode rewards reach this
        normalize_return: should we normalize the return?
        env_config: dictionary containing kwargs to pass to the environment;
            ``None`` means no kwargs

    Returns:
        model: trained model
        raw_rew_hist: per-episode sum of rewards after division by ``var_dim``
            (before any normalization)
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.nn import MLP
        from seagul.rl.models import PPOModel
        import torch

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2

        policy = MLP(input_size, output_size, num_layers, layer_size)
        value_fn = MLP(input_size, 1, num_layers, layer_size)
        model = PPOModel(policy, value_fn)

        model, rews, var_dict = ppo_dim("Pendulum-v0", 10000, model)

    """
    if act_var_schedule is None:
        act_var_schedule = [0.7]
    if env_config is None:
        env_config = {}

    # init everything
    # ==============================================================================
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = torch.double
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    actvar_lookup = make_variance_schedule(act_var_schedule, model, total_steps)
    model.action_var = actvar_lookup(0)

    obs_size = env.observation_space.shape[0]
    obs_mean = torch.zeros(obs_size)
    obs_var = torch.ones(obs_size)
    adv_mean = torch.zeros(1)
    adv_var = torch.ones(1)
    rew_mean = torch.zeros(1)
    rew_var = torch.ones(1)

    # copy.deepcopy broke for me with older version of torch. Using pickle for
    # this is weird but works fine (the model is our own object, not untrusted data).
    old_model = pickle.loads(pickle.dumps(model))
    pol_opt = torch.optim.Adam(model.policy.parameters(), lr=pol_lr)
    val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=val_lr)

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # init logging stuff
    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:

        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out once the last two episodes both met the reward threshold.
        if reward_stop is not None and len(raw_rew_hist) >= 2:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < epoch_batch_size:

            ep_obs, ep_act, ep_rew, ep_steps = do_rollout(env, model)

            # Scale rewards by var_dim of the post-transient part of the episode.
            ep_rew /= var_dim(ep_obs[transient_length:], order=1)

            raw_rew_hist.append(sum(ep_rew))
            ep_rew = (ep_rew - ep_rew.mean()) / (ep_rew.std() + 1e-6)

            batch_obs = torch.cat((batch_obs, ep_obs[:-1]))
            batch_act = torch.cat((batch_act, ep_act[:-1]))

            # Drop the last return so it lines up with ep_obs[:-1] / ep_act[:-1].
            ep_discrew = discount_cumsum(ep_rew, gamma)[:-1]

            # Normalize only this episode's returns; earlier episodes in the
            # batch were already normalized and must not be rescaled again.
            if normalize_return:
                rew_mean = update_mean(ep_discrew, rew_mean, cur_total_steps)
                rew_var = update_std(ep_discrew, rew_var, cur_total_steps)
                ep_discrew = (ep_discrew - rew_mean) / (rew_var + 1e-6)
            batch_discrew = torch.cat((batch_discrew, ep_discrew))

            # calculate this episode's advantages (no gradient needed here)
            with torch.no_grad():
                last_val = model.value_fn(ep_obs[-1]).reshape(-1, 1)
                ep_val = model.value_fn(ep_obs)
                ep_val[-1] = last_val

            deltas = ep_rew[:-1] + gamma * ep_val[1:] - ep_val[:-1]
            ep_adv = discount_cumsum(deltas, gamma * lam)
            batch_adv = torch.cat((batch_adv, ep_adv))

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        # make sure our advantages are zero mean and unit variance
        adv_mean = update_mean(batch_adv, adv_mean, cur_total_steps)
        adv_var = update_std(batch_adv, adv_var, cur_total_steps)
        batch_adv = (batch_adv - adv_mean) / (adv_var + 1e-6)

        # policy update
        # ========================================================================
        # At least one minibatch, so a batch smaller than pol_batch_size is still used.
        num_mbatch = max(1, int(batch_obs.shape[0] / pol_batch_size))

        # Update the policy using the PPO loss
        stop_policy_update = False
        for pol_epoch in range(pol_epochs):
            for i in range(num_mbatch):
                cur_sample = i * pol_batch_size
                mb_obs = batch_obs[cur_sample:cur_sample + pol_batch_size]
                mb_act = batch_act[cur_sample:cur_sample + pol_batch_size]
                mb_adv = batch_adv[cur_sample:cur_sample + pol_batch_size]

                logp = model.get_logp(mb_obs, mb_act).reshape(-1, act_size)
                # The old policy is a fixed reference; no graph is needed for it.
                with torch.no_grad():
                    old_logp = old_model.get_logp(mb_obs, mb_act).reshape(-1, act_size)
                r = torch.exp(logp - old_logp)
                clip_r = torch.clamp(r, 1 - eps, 1 + eps)
                pol_loss = -torch.min(r * mb_adv, clip_r * mb_adv).mean()

                # Sample estimate of KL(old || new) = E_old[log old - log new].
                approx_kl = (old_logp - logp).mean()
                if approx_kl > target_kl:
                    stop_policy_update = True
                    break

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

            if stop_policy_update:
                break

        # value_fn update
        # ========================================================================
        num_mbatch = max(1, int(batch_obs.shape[0] / val_batch_size))

        # Update value function with the standard L2 Loss
        for val_epoch in range(val_epochs):
            for i in range(num_mbatch):
                cur_sample = i * val_batch_size

                # predict and calculate loss for the batch
                val_preds = model.value_fn(batch_obs[cur_sample:cur_sample + val_batch_size])
                val_loss = ((val_preds - batch_discrew[cur_sample:cur_sample + val_batch_size]) ** 2).mean()

                # do the normal pytorch update
                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

        # update observation mean and variance
        obs_mean = update_mean(batch_obs, obs_mean, cur_total_steps)
        obs_var = update_std(batch_obs, obs_var, cur_total_steps)
        model.policy.state_means = obs_mean
        model.value_fn.state_means = obs_mean
        model.policy.state_std = obs_var
        model.value_fn.state_std = obs_var
        model.action_var = actvar_lookup(cur_total_steps)
        old_model = pickle.loads(pickle.dumps(model))

        # Detach so the histories do not keep every autograd graph alive.
        val_loss_hist.append(val_loss.detach())
        pol_loss_hist.append(pol_loss.detach())

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return model, raw_rew_hist, locals()
Beispiel #4
0
def ars(env_name,
        policy,
        n_epochs,
        n_workers=8,
        step_size=.02,
        n_delta=32,
        n_top=16,
        exp_noise=0.03,
        zero_policy=True,
        postprocess=postprocess_default):
    """
    Augmented Random Search
    https://arxiv.org/pdf/1803.07055

    Each epoch samples ``n_delta`` Gaussian directions and evaluates the policy
    at ``W + exp_noise * delta`` and ``W - exp_noise * delta`` in a process
    pool. It keeps the ``n_top`` directions with the best ``max(r+, r-)`` and
    steps the flat parameter vector along their return-weighted sum.

    Args:
        env_name: name of the gym environment to solve
        policy: torch module whose parameters are searched over; its
            ``state_means`` / ``state_std`` attributes are overwritten
        n_epochs: number of parameter updates to perform
        n_workers: number of processes in the rollout pool
        step_size: step size of the parameter update
        n_delta: number of search directions sampled per epoch
        n_top: number of best-performing directions used in each update
        exp_noise: scale of the perturbation applied along each direction
        zero_policy: if True, start the search from an all-zero parameter vector
        postprocess: forwarded to ``do_rollout_train`` for every rollout

    Returns:
        policy: ``policy`` with the final parameter vector written into it
        r_hist: per-epoch mean of the third rollout output (``plr`` / ``mlr``)
            over both evaluations of the top directions
    """
    torch.autograd.set_grad_enabled(False)

    pool = Pool(processes=n_workers)
    try:
        # The env is only needed for the observation size.
        env = gym.make(env_name)
        obs_size = env.observation_space.shape[0]
        env.close()

        W = torch.nn.utils.parameters_to_vector(policy.parameters())
        n_param = W.shape[0]

        if zero_policy:
            W = torch.zeros_like(W)

        r_hist = []
        s_mean = torch.zeros(obs_size)
        s_stdv = torch.ones(obs_size)

        total_steps = 0
        exp_dist = torch.distributions.Normal(torch.zeros(n_delta, n_param),
                                              torch.ones(n_delta, n_param))
        do_rollout_partial = partial(do_rollout_train, env_name, policy,
                                     postprocess)

        for _ in range(n_epochs):

            deltas = exp_dist.sample()
            # Rows [0, n_delta) are W + noise, rows [n_delta, 2*n_delta) are W - noise.
            pm_W = torch.cat((W + (deltas * exp_noise), W - (deltas * exp_noise)))

            results = pool.map(do_rollout_partial, pm_W)

            states = torch.empty(0)
            p_returns = []
            m_returns = []
            l_returns = []
            top_returns = []

            for p_result, m_result in zip(results[:n_delta], results[n_delta:]):
                ps, pr, plr = p_result
                ms, mr, mlr = m_result

                states = torch.cat((states, ms, ps), dim=0)
                p_returns.append(pr)
                m_returns.append(mr)
                # One (plus, minus) pair per direction so that top_idx selects
                # the rollouts of the chosen directions.
                l_returns.append(torch.stack((plr, mlr)))
                top_returns.append(max(pr, mr))

            top_idx = sorted(range(len(top_returns)),
                             key=lambda k: top_returns[k],
                             reverse=True)[:n_top]
            p_returns = torch.stack(p_returns)[top_idx]
            m_returns = torch.stack(m_returns)[top_idx]
            l_returns = torch.stack(l_returns)[top_idx]

            r_hist.append(l_returns.mean())

            ep_steps = states.shape[0]
            s_mean = update_mean(states, s_mean, total_steps)
            s_stdv = update_std(states, s_stdv, total_steps)
            total_steps += ep_steps

            policy.state_means = s_mean
            policy.state_std = s_stdv
            # Rebuild the partial so workers receive the updated statistics.
            do_rollout_partial = partial(do_rollout_train, env_name, policy,
                                         postprocess)

            # NOTE(review): the ARS paper scales by 1 / (n_top * sigma_R); this
            # code uses n_delta instead of n_top — confirm this is intended.
            W = W + (step_size / (n_delta * torch.cat(
                (p_returns, m_returns)).std() + 1e-6)) * torch.sum(
                    (p_returns - m_returns) * deltas[top_idx].T, dim=1)
    finally:
        # Always shut down the worker processes, even if a rollout raised.
        pool.terminate()

    torch.nn.utils.vector_to_parameters(W, policy.parameters())
    return policy, r_hist
Beispiel #5
0
def ars(env_name,
        policy,
        n_epochs,
        env_config=None,
        n_workers=8,
        step_size=.02,
        n_delta=32,
        n_top=16,
        exp_noise=0.03,
        zero_policy=True,
        learn_means=True,
        postprocess=postprocess_default):
    """
    Augmented Random Search
    https://arxiv.org/pdf/1803.07055

    Rollouts are handled by ``n_workers`` long-lived processes running
    ``worker_fn``. Each request is ``(W, state_means, state_std)``. Each reply
    is a triple ``(states, r, lr)``: ``r`` is used for the update and logged as
    "processed reward", and ``lr`` is logged as "reward".

    Args:
        env_name: name of the gym environment to solve
        policy: torch module whose parameters are searched over
        n_epochs: number of parameter updates to perform
        env_config: kwargs passed to ``gym.make`` and to the workers; ``None``
            means no kwargs
        n_workers: number of worker processes
        step_size: step size of the parameter update
        n_delta: number of search directions sampled per epoch
        n_top: number of best-performing directions used in each update
        exp_noise: scale of the perturbation applied along each direction
        zero_policy: if True, start the search from an all-zero parameter vector
        learn_means: currently unused by this function
        postprocess: forwarded to every worker

    Returns:
        policy: ``policy`` with the final parameters and observation statistics
        r_hist: per-epoch mean of the top directions' update returns
        lr_hist: per-epoch mean of the top directions' logged returns
    """
    torch.autograd.set_grad_enabled(False)
    if env_config is None:
        env_config = {}

    proc_list = []
    master_pipe_list = []

    try:
        for _ in range(n_workers):
            master_con, worker_con = Pipe()
            proc = Process(target=worker_fn,
                           args=(worker_con, env_name, env_config, policy,
                                 postprocess))
            proc.start()
            proc_list.append(proc)
            master_pipe_list.append(master_con)

        W = torch.nn.utils.parameters_to_vector(policy.parameters())
        n_param = W.shape[0]

        if zero_policy:
            W = torch.zeros_like(W)

        env = gym.make(env_name, **env_config)
        s_mean = policy.state_means
        s_std = policy.state_std
        total_steps = 0
        env.close()

        r_hist = []
        lr_hist = []

        exp_dist = torch.distributions.Normal(torch.zeros(n_delta, n_param),
                                              torch.ones(n_delta, n_param))

        for epoch in range(n_epochs):

            deltas = exp_dist.sample()
            # Rows [0, n_delta) are W + noise, rows [n_delta, 2*n_delta) are W - noise.
            pm_W = torch.cat((W + (deltas * exp_noise), W - (deltas * exp_noise)))

            # Job i goes to worker i % n_workers; replies are read back in the
            # same order, so results[i] corresponds to pm_W[i].
            # NOTE(review): every request is sent before any reply is read. If the
            # queued messages exceed the OS pipe buffers (large policies / long
            # episodes), master and workers could block on each other.
            for i, Ws in enumerate(pm_W):
                master_pipe_list[i % n_workers].send((Ws, s_mean, s_std))

            results = [master_pipe_list[i % n_workers].recv()
                       for i in range(len(pm_W))]

            states = torch.empty(0)
            p_returns = []
            m_returns = []
            l_returns = []
            top_returns = []

            for p_result, m_result in zip(results[:n_delta], results[n_delta:]):
                ps, pr, plr = p_result
                ms, mr, mlr = m_result

                states = torch.cat((states, ms, ps), dim=0)
                p_returns.append(pr)
                m_returns.append(mr)
                # One (plus, minus) pair per direction so that top_idx selects
                # the rollouts of the chosen directions.
                l_returns.append(torch.stack((plr, mlr)))
                top_returns.append(max(pr, mr))

            top_idx = sorted(range(len(top_returns)),
                             key=lambda k: top_returns[k],
                             reverse=True)[:n_top]
            p_returns = torch.stack(p_returns)[top_idx]
            m_returns = torch.stack(m_returns)[top_idx]
            l_returns = torch.stack(l_returns)[top_idx]

            lr_hist.append(l_returns.mean())
            r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

            ep_steps = states.shape[0]
            s_mean = update_mean(states, s_mean, total_steps)
            s_std = update_std(states, s_std, total_steps)
            total_steps += ep_steps

            if epoch % 5 == 0:
                print(
                    f"epoch: {epoch}, reward: {lr_hist[-1].item()}, processed reward: {r_hist[-1].item()} "
                )

            # NOTE(review): the ARS paper scales by 1 / (n_top * sigma_R); this
            # code uses n_delta instead of n_top — confirm this is intended.
            W = W + (step_size / (n_delta * torch.cat(
                (p_returns, m_returns)).std() + 1e-6)) * torch.sum(
                    (p_returns - m_returns) * deltas[top_idx].T, dim=1)
    except BaseException:
        # Workers may be mid-request; kill them instead of waiting on them.
        for proc in proc_list:
            proc.terminate()
        raise

    # Normal shutdown: ask each worker to stop, then reap it.
    for pipe in master_pipe_list:
        pipe.send("STOP")
    for proc in proc_list:
        proc.join()

    policy.state_means = s_mean
    policy.state_std = s_std
    torch.nn.utils.vector_to_parameters(W, policy.parameters())
    return policy, r_hist, lr_hist
Beispiel #6
0
    def learn(self, n_epochs):
        """
        Run ``n_epochs`` epochs of Augmented Random Search on ``self.policy``.

        Worker processes running ``worker_fn`` are started for the duration of
        this call and stopped before it returns. The call updates
        ``self.policy``'s parameters, ``state_means`` and ``state_std``. It
        extends ``self.r_hist`` and ``self.lr_hist`` and advances
        ``self.total_steps`` and ``self.total_epochs``.

        Args:
            n_epochs: number of parameter updates to perform

        Returns:
            The entries of ``self.lr_hist`` from index ``self.total_epochs`` (as
            it was on entry) onwards, i.e. this call's entries when the history
            length and epoch counter are kept in step.
        """
        torch.autograd.set_grad_enabled(False)

        proc_list = []
        master_pipe_list = []
        learn_start_idx = self.total_epochs  # ints are immutable; no copy needed

        try:
            for _ in range(self.n_workers):
                master_con, worker_con = Pipe()
                proc = Process(target=worker_fn,
                               args=(worker_con, self.env_name, self.env_config,
                                     self.policy, self.postprocessor, self.seed))
                proc.start()
                proc_list.append(proc)
                master_pipe_list.append(master_con)

            W = torch.nn.utils.parameters_to_vector(self.policy.parameters())
            n_param = W.shape[0]

            # NOTE(review): reseeding on every call means repeated learn() calls
            # replay the same sequence of deltas — confirm this is intended.
            torch.manual_seed(self.seed)
            exp_dist = torch.distributions.Normal(
                torch.zeros(self.n_delta, n_param),
                torch.ones(self.n_delta, n_param))

            for _ in range(n_epochs):

                deltas = exp_dist.sample()
                # Rows [0, n_delta) are W + noise, rows [n_delta, 2*n_delta) are W - noise.
                pm_W = torch.cat(
                    (W + (deltas * self.exp_noise), W - (deltas * self.exp_noise)))

                # Job i goes to worker i % n_workers; replies are read back in
                # the same order, so results[i] corresponds to pm_W[i].
                for i, Ws in enumerate(pm_W):
                    master_pipe_list[i % self.n_workers].send(
                        (Ws, self.policy.state_means, self.policy.state_std))

                results = [master_pipe_list[i % self.n_workers].recv()
                           for i in range(len(pm_W))]

                states = torch.empty(0)
                p_returns = []
                m_returns = []
                l_returns = []
                top_returns = []

                for p_result, m_result in zip(results[:self.n_delta],
                                              results[self.n_delta:]):
                    ps, pr, plr = p_result
                    ms, mr, mlr = m_result

                    states = torch.cat((states, ms, ps), dim=0)
                    p_returns.append(pr)
                    m_returns.append(mr)
                    # One (plus, minus) pair per direction so that top_idx
                    # selects the rollouts of the chosen directions.
                    l_returns.append(torch.stack((plr, mlr)))
                    top_returns.append(max(pr, mr))

                top_idx = sorted(range(len(top_returns)),
                                 key=lambda k: top_returns[k],
                                 reverse=True)[:self.n_top]
                p_returns = torch.stack(p_returns)[top_idx]
                m_returns = torch.stack(m_returns)[top_idx]
                l_returns = torch.stack(l_returns)[top_idx]

                self.lr_hist.append(l_returns.mean())
                self.r_hist.append((p_returns.mean() + m_returns.mean()) / 2)

                ep_steps = states.shape[0]
                self.policy.state_means = update_mean(states,
                                                      self.policy.state_means,
                                                      self.total_steps)
                self.policy.state_std = update_std(states, self.policy.state_std,
                                                   self.total_steps)

                self.total_steps += ep_steps
                self.total_epochs += 1

                # NOTE(review): the ARS paper scales by 1 / (n_top * sigma_R);
                # this code uses n_delta instead of n_top — confirm intended.
                W = W + (self.step_size / (self.n_delta * torch.cat(
                    (p_returns, m_returns)).std() + 1e-6)) * torch.sum(
                        (p_returns - m_returns) * deltas[top_idx].T, dim=1)
        except BaseException:
            # Workers may be mid-request; kill them instead of waiting on them.
            for proc in proc_list:
                proc.terminate()
            raise

        # Normal shutdown: ask each worker to stop, then reap it.
        for pipe in master_pipe_list:
            pipe.send("STOP")
        for proc in proc_list:
            proc.join()

        torch.nn.utils.vector_to_parameters(W, self.policy.parameters())
        return self.lr_hist[learn_start_idx:]
Beispiel #7
0
def ppo_visit(
        env_name,
        total_steps,
        model,
        vc=.01,
        replay_buf_size=int(5e4),
        act_std_schedule=(0.7,),
        epoch_batch_size=2048,
        gamma=0.99,
        lam=0.95,
        eps=0.2,
        seed=0,
        entropy_coef=0.0,
        sgd_batch_size=1024,
        lr_schedule=(3e-4,),
        sgd_epochs=10,
        target_kl=float('inf'),
        val_coef=.5,
        clip_val=True,
        env_no_term_steps=0,
        use_gpu=False,
        reward_stop=None,
        normalize_return=True,
        normalize_obs=True,
        normalize_adv=True,
        env_config=None
):
    """
    Implements proximal policy optimization with clipping, with an additional
    replay-buffer distance term subtracted from each step's reward.

    Every observation of a new episode has its reward reduced by ``vc`` times the
    Euclidean distance to its nearest stored observation in a replay buffer; the
    episode is then added to that buffer.

    Args:
        env_name: name of the openAI gym environment to solve
        total_steps: number of timesteps to run the PPO for
        model: model from seagul.rl.models. Contains policy and value fn
        vc: coefficient on the replay-buffer distance term subtracted from each reward
        replay_buf_size: capacity passed to the ReplayBuffer holding visited observations
        act_std_schedule: schedule to set the std of the policy. Will linearly interpolate values
        epoch_batch_size: rollouts are collected until at least this many environment steps
            have been taken in the current epoch
        gamma: discount applied to future rewards, usually close to 1
        lam: lambda for the Advantage estimation, usually close to 1
        eps: epsilon for the clipping (policy ratio and, if clip_val, value predictions),
            usually .1 or .2
        seed: seed for all the rngs
        entropy_coef: weight of the entropy bonus in the policy loss
        sgd_batch_size: minibatch size for both the policy and value function updates
        lr_schedule: learning rate schedule (interpolated over total_steps) for both Adam optimizers
        sgd_epochs: how many passes over each batch to make for the policy and value updates
        target_kl: if the approximate KL between the current and old policy on a minibatch
            exceeds this, all remaining updates for the current batch are skipped
        val_coef: weight applied to the value function loss
        clip_val: if True, clip value predictions to within eps of the old model's predictions
        env_no_term_steps: forwarded to do_rollout
        use_gpu: selects `device`; NOTE tensors are currently never moved to that device
        reward_stop: stop once the total reward of the last two episodes is >= this value
            (None disables early stopping)
        normalize_return: scale discounted returns by a running std estimate
        normalize_obs: maintain running observation mean/std and write them to the policy
            and value function
        normalize_adv: normalize each batch of advantages to zero mean and unit variance
        env_config: dictionary containing kwargs to pass to the environment (None means none)

    Returns:
        model: trained model
        raw_rew_hist: list with the total (undiscounted) reward of every episode, recorded
            before the vc distance term is applied
        var_dict: dictionary with all locals, for logging/debugging purposes

    Example:
        from seagul.nn import MLP
        from seagul.rl.models import PPOModel
        import torch

        input_size = 3
        output_size = 1
        layer_size = 64
        num_layers = 2

        policy = MLP(input_size, output_size, num_layers, layer_size)
        value_fn = MLP(input_size, 1, num_layers, layer_size)
        model = PPOModel(policy, value_fn)

        # assumes ppo_visit has been imported from the module that defines it
        model, rews, var_dict = ppo_visit("Pendulum-v0", 10000, model)

    """

    # Avoid a shared mutable default; None means "no extra env kwargs".
    if env_config is None:
        env_config = {}

    # init everything
    # ==============================================================================
    torch.set_num_threads(1)

    env = gym.make(env_name, **env_config)
    if isinstance(env.action_space, gym.spaces.Box):
        act_size = env.action_space.shape[0]
        act_dtype = torch.double
    else:
        raise NotImplementedError("trying to use unsupported action space", env.action_space)

    replay_buf = ReplayBuffer(env.observation_space.shape[0], act_size, replay_buf_size)

    actstd_lookup = make_schedule(act_std_schedule, total_steps)
    lr_lookup = make_schedule(lr_schedule, total_steps)

    # NOTE(review): the initial value is written to `action_var`, but the per-epoch
    # update at the bottom of the loop writes `action_std`; confirm which attribute
    # the model actually reads.
    model.action_var = actstd_lookup(0)
    sgd_lr = lr_lookup(0)

    obs_size = env.observation_space.shape[0]
    obs_mean = torch.zeros(obs_size)
    obs_std = torch.ones(obs_size)
    rew_mean = torch.zeros(1)
    rew_std = torch.ones(1)

    # copy.deepcopy broke for me with older version of torch. Using pickle for this is weird but works fine
    old_model = pickle.loads(pickle.dumps(model))

    # seed all our RNGs
    env.seed(seed)
    torch.manual_seed(seed)
    np.random.seed(seed)

    # set defaults, and decide if we are using a GPU or not
    use_cuda = torch.cuda.is_available() and use_gpu
    device = torch.device("cuda:0" if use_cuda else "cpu")

    # init logging stuff
    raw_rew_hist = []
    val_loss_hist = []
    pol_loss_hist = []
    progress_bar = tqdm.tqdm(total=total_steps)
    cur_total_steps = 0
    progress_bar.update(0)
    early_stop = False

    # Train until we hit our total steps or reach our reward threshold
    # ==============================================================================
    while cur_total_steps < total_steps:
        # Optimizers are rebuilt every epoch so the scheduled learning rate takes
        # effect (this also resets Adam's moment estimates each epoch).
        pol_opt = torch.optim.Adam(model.policy.parameters(), lr=sgd_lr)
        val_opt = torch.optim.Adam(model.value_fn.parameters(), lr=sgd_lr)

        batch_obs = torch.empty(0)
        batch_act = torch.empty(0)
        batch_adv = torch.empty(0)
        batch_discrew = torch.empty(0)
        cur_batch_steps = 0

        # Bail out if the last two episodes both met the reward threshold.
        # `is not None` so that a threshold of 0 is honoured.
        if reward_stop is not None and len(raw_rew_hist) >= 2:
            if raw_rew_hist[-1] >= reward_stop and raw_rew_hist[-2] >= reward_stop:
                early_stop = True
                break

        # construct batch data from rollouts
        # ==============================================================================
        while cur_batch_steps < epoch_batch_size:
            ep_obs, ep_act, ep_rew, ep_steps, ep_term = do_rollout(env, model, env_no_term_steps)

            # Logged before the distance term below modifies ep_rew.
            raw_rew_hist.append(sum(ep_rew).item())

            # Subtract vc * (distance from each observation to its nearest row of
            # replay_buf.obs1_buf).
            # NOTE(review): this lowers the reward of observations far from stored ones;
            # if a novelty *bonus* was intended the sign is flipped. Unfilled buffer rows
            # (if zero-initialised) also take part in the min — confirm against ReplayBuffer.
            for i, obs in enumerate(ep_obs):
                ep_rew[i] -= (np.min(np.linalg.norm(obs - replay_buf.obs1_buf, axis=1)))*vc

            replay_buf.store(ep_obs, ep_obs, ep_act, ep_rew, ep_rew)

            batch_obs = torch.cat((batch_obs, ep_obs[:-1]))
            batch_act = torch.cat((batch_act, ep_act[:-1]))

            # Episode was cut off rather than terminated: bootstrap the final reward
            # with the value estimate of the last observation.
            if not ep_term:
                ep_rew[-1] = model.value_fn(ep_obs[-1]).detach()

            ep_discrew = discount_cumsum(ep_rew, gamma)

            if normalize_return:
                # Both running statistics are updated from the same (current-episode) data.
                rew_mean = update_mean(ep_discrew, rew_mean, cur_total_steps)
                rew_std = update_std(ep_discrew, rew_std, cur_total_steps)
                ep_discrew = ep_discrew / (rew_std + 1e-6)

            batch_discrew = torch.cat((batch_discrew, ep_discrew[:-1]))

            ep_val = model.value_fn(ep_obs)

            # GAE: TD residuals discounted by gamma * lam
            deltas = ep_rew[:-1] + gamma * ep_val[1:] - ep_val[:-1]
            ep_adv = discount_cumsum(deltas.detach(), gamma * lam)
            batch_adv = torch.cat((batch_adv, ep_adv))

            cur_batch_steps += ep_steps
            cur_total_steps += ep_steps

        # make sure our advantages are zero mean and unit variance
        if normalize_adv:
            batch_adv = (batch_adv - batch_adv.mean()) / (batch_adv.std() + 1e-6)

        # Always run at least one minibatch: otherwise a batch smaller than
        # sgd_batch_size leaves pol_loss / val_loss unbound at the end of the epoch.
        num_mbatch = max(1, int(batch_obs.shape[0] / sgd_batch_size))
        kl_exceeded = False
        # Update the policy using the PPO loss
        for pol_epoch in range(sgd_epochs):
            for i in range(num_mbatch):
                # policy update
                # ========================================================================
                cur_sample = i * sgd_batch_size

                # Slice out the current minibatch (tensors stay on the CPU; `device`
                # is computed above but nothing is moved to it).
                local_obs = batch_obs[cur_sample:cur_sample + sgd_batch_size]
                local_act = batch_act[cur_sample:cur_sample + sgd_batch_size]
                local_adv = batch_adv[cur_sample:cur_sample + sgd_batch_size]
                local_val = batch_discrew[cur_sample:cur_sample + sgd_batch_size]

                # Compute the loss
                logp = model.get_logp(local_obs, local_act).reshape(-1, act_size)
                old_logp = old_model.get_logp(local_obs, local_act).reshape(-1, act_size)
                mean_entropy = -(logp*torch.exp(logp)).mean()

                r = torch.exp(logp - old_logp)
                clip_r = torch.clamp(r, 1 - eps, 1 + eps)

                pol_loss = -torch.min(r * local_adv, clip_r * local_adv).mean() - entropy_coef*mean_entropy

                # Stop *all* remaining updates on this batch once the policy has drifted
                # too far; breaking only the inner loop would let later epochs continue.
                approx_kl = ((logp - old_logp)**2).mean()
                if approx_kl > target_kl:
                    kl_exceeded = True
                    break

                pol_opt.zero_grad()
                pol_loss.backward()
                pol_opt.step()

                # value_fn update
                # ========================================================================
                val_preds = model.value_fn(local_obs)
                if clip_val:
                    old_val_preds = old_model.value_fn(local_obs)
                    val_preds_clipped = old_val_preds + torch.clamp(val_preds - old_val_preds, -eps, eps)
                    val_loss1 = (val_preds_clipped - local_val)**2
                    val_loss2 = (val_preds - local_val)**2
                    val_loss = val_coef*torch.max(val_loss1, val_loss2).mean()
                else:
                    val_loss = val_coef*((val_preds - local_val) ** 2).mean()

                val_opt.zero_grad()
                val_loss.backward()
                val_opt.step()

            if kl_exceeded:
                break

        # update observation mean and variance

        if normalize_obs:
            obs_mean = update_mean(batch_obs, obs_mean, cur_total_steps)
            obs_std = update_std(batch_obs, obs_std, cur_total_steps)
            model.policy.state_means = obs_mean
            model.value_fn.state_means = obs_mean
            model.policy.state_std = obs_std
            model.value_fn.state_std = obs_std

        # NOTE(review): writes `action_std`, while initialisation wrote `action_var`.
        model.action_std = actstd_lookup(cur_total_steps)
        sgd_lr = lr_lookup(cur_total_steps)

        old_model = pickle.loads(pickle.dumps(model))
        # Detach so the histories hold plain values, not autograd-tracked tensors.
        val_loss_hist.append(val_loss.detach())
        pol_loss_hist.append(pol_loss.detach())

        progress_bar.update(cur_batch_steps)

    progress_bar.close()
    return model, raw_rew_hist, locals()