Exemplo n.º 1
0
def setup_pytorch_for_mpi():
    """
    Cap PyTorch's intra-op thread count for this process.

    When PyTorch is already limited to a single thread nothing is changed.
    Otherwise the current thread count is divided by ``num_procs()`` and the
    result (never less than one) becomes the new thread count, so that
    several processes do not each claim every CPU thread.
    """
    current_threads = torch.get_num_threads()
    if current_threads == 1:
        return
    per_process_share = int(current_threads / num_procs())
    torch.set_num_threads(max(per_process_share, 1))
Exemplo n.º 2
0
def trpo(
        env_fn,
        actor_critic=core.ActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        delta=0.01,
        vf_lr=1e-3,
        train_v_iters=80,
        damping_coeff=0.1,
        cg_iters=10,
        backtrack_iters=10,
        backtrack_coeff=0.8,
        lam=0.97,
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=10,
        algo="trpo",
):
    """
    Train an agent with Trust Region Policy Optimization (TRPO) or
    Natural Policy Gradient (NPG).

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The agent's main model which, for state ``x`` and
            action ``a``, returns the following outputs:

            ============  ================  ========================================
            Symbol        Shape             Description
            ============  ================  ========================================
            ``pi``        (batch, act_dim)  | Samples actions from policy given
                                            | states.
            ``logp``      (batch,)          | Gives log probability, according to
                                            | the policy, of taking actions ``a``
                                            | in states ``x``.
            ``logp_pi``   (batch,)          | Gives log probability, according to
                                            | the policy, of the action sampled by
                                            | ``pi``.
            ``info``      N/A               | A dict of any intermediate quantities
                                            | (from calculating the policy or log
                                            | probabilities) which are needed for
                                            | analytically computing KL divergence.
                                            | (eg sufficient statistics of the
                                            | distributions)
            ``info_phs``  N/A               | A dict of placeholders for old values
                                            | of the entries in ``info``.
            ``d_kl``      ()                | The mean KL
                                            | divergence between the current policy
                                            | (``pi``) and the old policy (as
                                            | specified by the inputs to
                                            | ``info``) over the batch of
                                            | states given in ``x``.
            ``v``         (batch,)          | Gives the value estimate for states
                                            | in ``x``. (Critical: make sure
                                            | to flatten this!)
            ============  ================  ========================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic
            function you provided to TRPO. The dict is copied before the
            ``action_space`` entry is added, so the caller's dict (and the
            shared default) is never modified.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        delta (float): KL-divergence limit for TRPO / NPG update.
            (Should be small for stability. Values like 0.01, 0.05.)

        vf_lr (float): Learning rate for value function optimizer.

        train_v_iters (int): Number of gradient descent steps to take on
            value function per epoch.

        damping_coeff (float): Artifact for numerical stability, should be
            smallish. Adjusts Hessian-vector product calculation:

            .. math:: Hv \\rightarrow (\\alpha I + H)v

            where :math:`\\alpha` is the damping coefficient.
            Probably don't play with this hyperparameter.

        cg_iters (int): Number of iterations of conjugate gradient to perform.
            Increasing this will lead to a more accurate approximation
            to :math:`H^{-1} g`, and possibly slightly-improved performance,
            but at the cost of slowing things down.

            Also probably don't play with this hyperparameter.

        backtrack_iters (int): Maximum number of steps allowed in the
            backtracking line search. Since the line search usually doesn't
            backtrack, and usually only steps back once when it does, this
            hyperparameter doesn't often matter.

        backtrack_coeff (float): How far back to step during backtracking line
            search. (Always between 0 and 1, usually above 0.5.)

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

        algo: Either 'trpo' or 'npg': this code supports both, since they are
            almost the same.

    Raises:
        ValueError: If ``algo`` is neither ``'trpo'`` nor ``'npg'``.

    """
    # Fail fast: an unknown algo would otherwise only surface after a full
    # epoch of data collection, as an unbound ``kl`` inside update().
    if algo not in ("trpo", "npg"):
        raise ValueError("algo must be 'trpo' or 'npg', got %r" % (algo,))

    setup_pytorch_for_mpi()

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture.
    # Work on a copy so neither the caller's dict nor the shared default
    # argument is mutated across calls.
    ac_kwargs = dict(ac_kwargs)
    ac_kwargs["action_space"] = env.action_space

    # Main model
    actor_critic = actor_critic(in_features=obs_dim[0], **ac_kwargs)

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    if isinstance(env.action_space, Box):
        info_shapes = {
            "old_mu": [env.action_space.shape[-1]],
            "old_log_std": [env.action_space.shape[-1]],
        }
    else:
        info_shapes = {"old_logits": [env.action_space.n]}
    buf = GAEBuffer(obs_dim, act_dim, local_steps_per_epoch, info_shapes,
                    gamma, lam)

    # Count variables
    var_counts = tuple(
        core.count_vars(module)
        for module in [actor_critic.policy, actor_critic.value_function])
    logger.log("\nNumber of parameters: \t pi: %d, \t v: %d\n" % var_counts)

    # Optimizer for value function
    train_vf = torch.optim.Adam(actor_critic.value_function.parameters(),
                                lr=vf_lr)

    # Sync params across processes
    sync_all_params(actor_critic.parameters())

    def cg(Ax, b):
        """
        Conjugate gradient algorithm
        (see https://en.wikipedia.org/wiki/Conjugate_gradient_method)

        Approximately solves ``A x = b`` in ``cg_iters`` iterations, where
        ``Ax`` is a callable returning the product of ``A`` with a vector.
        ``b`` is left unmodified.
        """
        x = torch.zeros_like(b)
        # Note: should be 'b - Ax(x)', but for x=0, Ax(x)=0. Change if doing warm start.
        # r and p must be independent copies: r is updated in place below,
        # and aliasing it with b or p would corrupt the caller's gradient
        # and the first search direction.
        r = b.clone()
        p = r.clone()
        r_dot_old = torch.dot(r, r)
        for _ in range(cg_iters):
            z = Ax(p)
            alpha = r_dot_old / (torch.dot(p, z) + EPS)
            x += alpha * p
            r -= alpha * z
            r_dot_new = torch.dot(r, r)
            p = r + (r_dot_new / r_dot_old) * p
            r_dot_old = r_dot_new
        return x

    def update():
        """Run one policy update (TRPO or NPG) and fit the value function."""
        inputs = [torch.Tensor(x) for x in buf.get()]
        # The trailing entries of the buffer output are the policy "info"
        # tensors, ordered like buf.sorted_info_keys.
        obs, act, adv, ret, logp_old = inputs[:-len(buf.sorted_info_keys)]
        policy_args = dict(
            zip(buf.sorted_info_keys, inputs[-len(buf.sorted_info_keys):]))

        # Main outputs from computation graph
        _, logp, _, _, d_kl, v = actor_critic(obs, act, **policy_args)

        # Prepare hessian func, gradient eval
        ratio = (logp - logp_old).exp()  # pi(a|s) / pi_old(a|s)
        pi_l_old = -(ratio * adv).mean()
        v_l_old = F.mse_loss(v, ret)

        g = core.flat_grad(pi_l_old,
                           actor_critic.policy.parameters(),
                           retain_graph=True)
        g = torch.from_numpy(mpi_avg(g.numpy()))
        pi_l_old = mpi_avg(pi_l_old.item())

        def Hx(x):
            # Damped Hessian-vector product: (damping_coeff * I + H) x
            hvp = core.hessian_vector_product(d_kl, actor_critic.policy, x)
            if damping_coeff > 0:
                hvp += damping_coeff * x
            return torch.from_numpy(mpi_avg(hvp.numpy()))

        # Core calculations for TRPO or NPG
        x = cg(Hx, g)
        alpha = torch.sqrt(2 * delta / (torch.dot(x, Hx(x)) + EPS))
        old_params = parameters_to_vector(actor_critic.policy.parameters())

        def set_and_eval(step):
            # Move the policy to old_params - step * alpha * x and report the
            # resulting KL and surrogate loss.
            vector_to_parameters(old_params - alpha * x * step,
                                 actor_critic.policy.parameters())
            _, logp, _, _, d_kl = actor_critic.policy(obs, act, **policy_args)
            ratio = (logp - logp_old).exp()
            pi_loss = -(ratio * adv).mean()
            return mpi_avg(d_kl.item()), mpi_avg(pi_loss.item())

        if algo == "npg":
            kl, pi_l_new = set_and_eval(step=1.0)

        elif algo == "trpo":
            for j in range(backtrack_iters):
                kl, pi_l_new = set_and_eval(step=backtrack_coeff**j)
                if kl <= delta and pi_l_new <= pi_l_old:
                    logger.log(
                        "Accepting new params at step %d of line search." % j)
                    logger.store(BacktrackIters=j)
                    break

                if j == backtrack_iters - 1:
                    logger.log("Line search failed! Keeping old params.")
                    logger.store(BacktrackIters=j)
                    # step=0.0 restores the old parameters
                    kl, pi_l_new = set_and_eval(step=0.0)

        # Value function updates
        for _ in range(train_v_iters):
            v = actor_critic.value_function(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function gradient step
            train_vf.zero_grad()
            v_loss.backward()
            average_gradients(train_vf.param_groups)
            train_vf.step()

        v = actor_critic.value_function(obs)
        v_l_new = F.mse_loss(v, ret)

        # Log changes from update
        logger.store(
            LossPi=pi_l_old,
            LossV=v_l_old,
            KL=kl,
            DeltaLossPi=(pi_l_new - pi_l_old),
            DeltaLossV=(v_l_new - v_l_old),
        )

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        actor_critic.eval()
        for t in range(local_steps_per_epoch):
            a, _, logp_t, info_t, _, v_t = actor_critic(
                torch.Tensor(o.reshape(1, -1)))

            # save and log
            buf.store(
                o,
                a.detach().numpy(),
                r,
                v_t.item(),
                logp_t.detach().numpy(),
                core.values_as_sorted_list(info_t),
            )
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a.detach().numpy()[0])
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print("Warning: trajectory cut off by epoch at %d steps." %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = (r if d else actor_critic.value_function(
                    torch.Tensor(o.reshape(1, -1))).item())
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({"env": env}, actor_critic, None)

        # Perform TRPO or NPG update!
        actor_critic.train()
        update()

        # Log info about epoch
        logger.log_tabular("Epoch", epoch)
        logger.log_tabular("EpRet", with_min_and_max=True)
        logger.log_tabular("EpLen", average_only=True)
        logger.log_tabular("VVals", with_min_and_max=True)
        logger.log_tabular("TotalEnvInteracts", (epoch + 1) * steps_per_epoch)
        logger.log_tabular("LossPi", average_only=True)
        logger.log_tabular("LossV", average_only=True)
        logger.log_tabular("DeltaLossPi", average_only=True)
        logger.log_tabular("DeltaLossV", average_only=True)
        logger.log_tabular("KL", average_only=True)
        if algo == "trpo":
            logger.log_tabular("BacktrackIters", average_only=True)
        logger.log_tabular("Time", time.time() - start_time)
        logger.dump_tabular()
Exemplo n.º 3
0
def ppo(
        env_fn,
        actor_critic=core.ActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        clip_ratio=0.2,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_pi_iters=80,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=1000,
        target_kl=0.01,
        logger_kwargs=dict(),
        save_freq=10,
        use_gpu=False,
        gpu_parallel=False,
):
    """
    Train an agent with Proximal Policy Optimization (clipped objective).

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The agent's main model which is composed of
            the policy and value function model, where the policy takes
            some state, ``x`` and action ``a``, and value function takes
            the state ``x``. The model returns a tuple of:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a``
                                           | in states ``x``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x``. (Critical: make sure
                                           | to flatten this via .squeeze()!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic
            class you provided to PPO. The dict is copied before the
            ``action_space`` entry is added, so the caller's dict (and the
            shared default) is never modified.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        clip_ratio (float): Hyperparameter for clipping in the policy objective.
            Roughly: how far can the new policy go from the old policy while
            still profiting (improving the objective function)? The new policy
            can still go farther than the clip_ratio says, but it doesn't help
            on the objective anymore. (Usually small, 0.1 to 0.3.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_pi_iters (int): Maximum number of gradient descent steps to take
            on policy loss per epoch. (Early stopping may cause optimizer
            to take fewer than this.)

        train_v_iters (int): Number of gradient descent steps to take on
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        target_kl (float): Roughly what KL divergence we think is appropriate
            between new and old policies after an update. This will get used
            for early stopping. (Usually small, 0.01 or 0.05.)

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

        use_gpu (bool): Place the model and training tensors on CUDA. Ignored
            (CPU is used) when ``torch.cuda.is_available()`` is false.

        gpu_parallel (bool): Wrap the model in ``torch.nn.DataParallel``.
            Only takes effect when the GPU is actually in use.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture.
    # Work on a copy so neither the caller's dict nor the shared default
    # argument is mutated across calls.
    ac_kwargs = dict(ac_kwargs)
    ac_kwargs['action_space'] = env.action_space

    # Main model
    actor_critic = actor_critic(in_features=obs_dim[0], **ac_kwargs)
    # Handle on the bare model: torch.nn.DataParallel does not forward
    # attribute access, so .policy / .value_function are read from here.
    ac_model = actor_critic

    # Decide whether the GPU is used; fall back to CPU without CUDA.
    if not torch.cuda.is_available():
        use_gpu = False
    device = torch.device("cuda" if use_gpu else "cpu")
    # DataParallel requires the module to live on a CUDA device, so only
    # wrap when the GPU is actually in use.
    gpu_parallel = bool(use_gpu and gpu_parallel)
    if gpu_parallel:
        actor_critic = torch.nn.DataParallel(actor_critic)
    actor_critic = actor_critic.to(device)

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = PPOBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Count variables
    var_counts = tuple(
        core.count_vars(module)
        for module in [ac_model.policy, ac_model.value_function])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Optimizers
    train_pi = torch.optim.Adam(ac_model.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(ac_model.value_function.parameters(),
                               lr=vf_lr)

    # Sync params across processes
    sync_all_params(actor_critic.parameters())

    def update():
        """Run the PPO policy and value-function updates on the buffer."""
        temp_get = buf.get()
        obs, act, adv, ret, logp_old = [
            torch.Tensor(x).to(device) for x in temp_get
        ]

        # Training policy
        _, logp, _ = ac_model.policy(obs, act)
        ratio = (logp - logp_old).exp()
        min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                              (1 - clip_ratio) * adv)

        pi_l_old = -(torch.min(ratio * adv, min_adv)).mean()
        ent = (-logp).mean()  # a sample estimate for entropy

        for i in range(train_pi_iters):
            # Output from policy function graph
            _, logp, _ = ac_model.policy(obs, act)
            # PPO policy objective
            ratio = (logp - logp_old).exp()
            min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                                  (1 - clip_ratio) * adv)

            pi_loss = -(torch.min(ratio * adv, min_adv)).mean()

            # Policy gradient step
            train_pi.zero_grad()
            pi_loss.backward()
            average_gradients(train_pi.param_groups)
            train_pi.step()

            # Early stopping once the sampled KL exceeds 1.5 * target_kl
            _, logp, _ = ac_model.policy(obs, act)
            kl = (logp_old - logp).mean()
            kl = mpi_avg(kl.item())
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)

        # Training value function
        v = ac_model.value_function(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            # Output from value function graph
            v = ac_model.value_function(obs)
            # PPO value function objective
            v_loss = F.mse_loss(v, ret)

            # Value function gradient step
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Log changes from update
        _, logp, _, v = actor_critic(obs, act)
        ratio = (logp - logp_old).exp()
        min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                              (1 - clip_ratio) * adv)
        pi_l_new = -(torch.min(ratio * adv, min_adv)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()  # a sample estimate for KL-divergence
        clipped = (ratio > (1 + clip_ratio)) | (ratio < (1 - clip_ratio))
        cf = (clipped.float()).mean()
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        actor_critic.eval()
        for t in range(local_steps_per_epoch):
            a, _, logp_t, v_t = actor_critic(
                torch.Tensor(o.reshape(1, -1)).to(device))

            # save and log
            buf.store(o,
                      a.cpu().detach().numpy(), r, v_t.item(),
                      logp_t.cpu().detach().numpy())
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a.cpu().detach().numpy()[0])
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else ac_model.value_function(
                    torch.Tensor(o.reshape(1, -1)).to(device)).item()
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, actor_critic, None)

        # Perform PPO update!
        actor_critic.train()
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('ClipFrac', average_only=True)
        logger.log_tabular('StopIter', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()
Exemplo n.º 4
0
def vpg(env_fn,
        actor_critic=core.ActorCritic,
        ac_kwargs=dict(),
        seed=0,
        steps_per_epoch=4000,
        epochs=50,
        gamma=0.99,
        pi_lr=3e-4,
        vf_lr=1e-3,
        train_v_iters=80,
        lam=0.97,
        max_ep_len=1000,
        logger_kwargs=dict(),
        save_freq=10):
    """
    Train an agent with Vanilla Policy Gradient (with GAE-Lambda advantages).

    Args:
        env_fn : A function which creates a copy of the environment.
            The environment must satisfy the OpenAI Gym API.

        actor_critic: The agent's main model which is composed of
            the policy and value function model, where the policy takes
            some state, ``x``, and action, ``a``, and value function takes
            the state ``x``. The model returns a tuple of:

            ===========  ================  ======================================
            Symbol       Shape             Description
            ===========  ================  ======================================
            ``pi``       (batch, act_dim)  | Samples actions from policy given
                                           | states.
            ``logp``     (batch,)          | Gives log probability, according to
                                           | the policy, of taking actions ``a``
                                           | in states ``x``.
            ``logp_pi``  (batch,)          | Gives log probability, according to
                                           | the policy, of the action sampled by
                                           | ``pi``.
            ``v``        (batch,)          | Gives the value estimate for states
                                           | in ``x``. (Critical: make sure
                                           | to flatten this via .squeeze()!)
            ===========  ================  ======================================

        ac_kwargs (dict): Any kwargs appropriate for the actor_critic
            class you provided to VPG. The dict is copied before the
            ``action_space`` entry is added, so the caller's dict (and the
            shared default) is never modified.

        seed (int): Seed for random number generators.

        steps_per_epoch (int): Number of steps of interaction (state-action pairs)
            for the agent and the environment in each epoch.

        epochs (int): Number of epochs of interaction (equivalent to
            number of policy updates) to perform.

        gamma (float): Discount factor. (Always between 0 and 1.)

        pi_lr (float): Learning rate for policy optimizer.

        vf_lr (float): Learning rate for value function optimizer.

        train_v_iters (int): Number of gradient descent steps to take on
            value function per epoch.

        lam (float): Lambda for GAE-Lambda. (Always between 0 and 1,
            close to 1.)

        max_ep_len (int): Maximum length of trajectory / episode / rollout.

        logger_kwargs (dict): Keyword args for EpochLogger.

        save_freq (int): How often (in terms of gap between epochs) to save
            the current policy and value function.

    """

    logger = EpochLogger(**logger_kwargs)
    logger.save_config(locals())

    seed += 10000 * proc_id()
    torch.manual_seed(seed)
    np.random.seed(seed)

    env = env_fn()
    obs_dim = env.observation_space.shape
    act_dim = env.action_space.shape

    # Share information about action space with policy architecture.
    # Work on a copy so neither the caller's dict nor the shared default
    # argument is mutated across calls.
    ac_kwargs = dict(ac_kwargs)
    ac_kwargs['action_space'] = env.action_space

    # Main model
    actor_critic = actor_critic(in_features=obs_dim[0], **ac_kwargs)

    # Experience buffer
    local_steps_per_epoch = int(steps_per_epoch / num_procs())
    buf = VPGBuffer(obs_dim, act_dim, local_steps_per_epoch, gamma, lam)

    # Count variables
    var_counts = tuple(
        core.count_vars(module)
        for module in [actor_critic.policy, actor_critic.value_function])
    logger.log('\nNumber of parameters: \t pi: %d, \t v: %d\n' % var_counts)

    # Optimizers
    train_pi = torch.optim.Adam(actor_critic.policy.parameters(), lr=pi_lr)
    train_v = torch.optim.Adam(actor_critic.value_function.parameters(),
                               lr=vf_lr)

    # Sync params across processes
    sync_all_params(actor_critic.parameters())

    def update():
        """Take one policy gradient step, then fit the value function."""
        obs, act, adv, ret, logp_old = [torch.Tensor(x) for x in buf.get()]

        # Policy outputs for the collected batch
        _, logp, _ = actor_critic.policy(obs, act)
        ent = (-logp).mean()  # a sample estimate for entropy

        # VPG policy objective
        pi_loss = -(logp * adv).mean()

        # Policy gradient step
        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function learning
        v = actor_critic.value_function(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            # Output from value function graph
            v = actor_critic.value_function(obs)
            # VPG value objective
            v_loss = F.mse_loss(v, ret)

            # Value function gradient step
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Log changes from update
        _, logp, _, v = actor_critic(obs, act)
        pi_l_new = -(logp * adv).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()  # a sample estimate for KL-divergence
        logger.store(LossPi=pi_loss,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     DeltaLossPi=(pi_l_new - pi_loss),
                     DeltaLossV=(v_l_new - v_l_old))

    start_time = time.time()
    o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

    # Main loop: collect experience in env and update/log each epoch
    for epoch in range(epochs):
        actor_critic.eval()
        for t in range(local_steps_per_epoch):
            a, _, logp_t, v_t = actor_critic(torch.Tensor(o.reshape(1, -1)))

            # save and log
            buf.store(o,
                      a.detach().numpy(), r, v_t.item(),
                      logp_t.detach().numpy())
            logger.store(VVals=v_t)

            o, r, d, _ = env.step(a.detach().numpy()[0])
            ep_ret += r
            ep_len += 1

            terminal = d or (ep_len == max_ep_len)
            if terminal or (t == local_steps_per_epoch - 1):
                if not (terminal):
                    print('Warning: trajectory cut off by epoch at %d steps.' %
                          ep_len)
                # if trajectory didn't reach terminal state, bootstrap value target
                last_val = r if d else actor_critic.value_function(
                    torch.Tensor(o.reshape(1, -1))).item()
                buf.finish_path(last_val)
                if terminal:
                    # only save EpRet / EpLen if trajectory finished
                    logger.store(EpRet=ep_ret, EpLen=ep_len)
                o, r, d, ep_ret, ep_len = env.reset(), 0, False, 0, 0

        # Save model
        if (epoch % save_freq == 0) or (epoch == epochs - 1):
            logger.save_state({'env': env}, actor_critic, None)

        # Perform VPG update!
        actor_critic.train()
        update()

        # Log info about epoch
        logger.log_tabular('Epoch', epoch)
        logger.log_tabular('EpRet', with_min_and_max=True)
        logger.log_tabular('EpLen', average_only=True)
        logger.log_tabular('VVals', with_min_and_max=True)
        logger.log_tabular('TotalEnvInteracts', (epoch + 1) * steps_per_epoch)
        logger.log_tabular('LossPi', average_only=True)
        logger.log_tabular('LossV', average_only=True)
        logger.log_tabular('DeltaLossPi', average_only=True)
        logger.log_tabular('DeltaLossV', average_only=True)
        logger.log_tabular('Entropy', average_only=True)
        logger.log_tabular('KL', average_only=True)
        logger.log_tabular('Time', time.time() - start_time)
        logger.dump_tabular()