Example no. 1
    def update():
        obs, act, adv, ret, logp_old = [torch.Tensor(x) for x in buf.get()]

        # Policy outputs from computation graph
        _, logp, _ = actor_critic.policy(obs, act)
        ent = (-logp).mean()  # a sample estimate for entropy

        # VPG policy objective
        pi_loss = -(logp * adv).mean()

        # Policy gradient step
        train_pi.zero_grad()
        pi_loss.backward()
        average_gradients(train_pi.param_groups)
        train_pi.step()

        # Value function learning
        v = actor_critic.value_function(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            # Output from value function graph
            v = actor_critic.value_function(obs)
            # VPG value objective
            v_loss = F.mse_loss(v, ret)

            # Value function gradient step
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Log changes from update
        _, logp, _, v = actor_critic(obs, act)
        pi_l_new = -(logp * adv).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()  # a sample estimate for KL-divergence
        logger.store(
            LossPi=pi_loss,
            LossV=v_l_old,
            KL=kl,
            Entropy=ent,
            DeltaLossPi=(pi_l_new - pi_loss),
            DeltaLossV=(v_l_new - v_l_old),
        )
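
A minimal sketch of what the `average_gradients` helper called in these update functions might look like, assuming gradients are synchronized over MPI with mpi4py and live on CPU; the communicator handling and in-place averaging here are illustrative assumptions, not the original implementation:

    from mpi4py import MPI
    import numpy as np

    def average_gradients(param_groups):
        """Average the .grad buffers of every parameter across all MPI processes."""
        comm = MPI.COMM_WORLD
        num_procs = comm.Get_size()
        for group in param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue
                grad = p.grad.numpy()          # NumPy view sharing memory with p.grad (CPU tensors assumed)
                avg = np.zeros_like(grad)
                comm.Allreduce(grad, avg, op=MPI.SUM)
                grad[...] = avg / num_procs    # write the process-averaged gradient back in place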
Example no. 2
    def update():
        inputs = [torch.Tensor(x) for x in buf.get()]
        obs, act, adv, ret, logp_old = inputs[:-len(buf.sorted_info_keys)]
        policy_args = dict(
            zip(buf.sorted_info_keys, inputs[-len(buf.sorted_info_keys):]))

        # Main outputs from computation graph
        _, logp, _, _, d_kl, v = actor_critic(obs, act, **policy_args)

        # Prepare hessian func, gradient eval
        ratio = (logp - logp_old).exp()  # pi(a|s) / pi_old(a|s)
        pi_l_old = -(ratio * adv).mean()
        v_l_old = F.mse_loss(v, ret)

        g = core.flat_grad(pi_l_old,
                           actor_critic.policy.parameters(),
                           retain_graph=True)
        g = torch.from_numpy(mpi_avg(g.numpy()))
        pi_l_old = mpi_avg(pi_l_old.item())

        def Hx(x):
            # Hessian-vector product of the sample-average KL w.r.t. the policy
            # parameters, with optional damping for numerical stability
            hvp = core.hessian_vector_product(d_kl, actor_critic.policy, x)
            if damping_coeff > 0:
                hvp += damping_coeff * x
            return torch.from_numpy(mpi_avg(hvp.numpy()))

        # Core calculations for TRPO or NPG: approximately solve H x = g with
        # conjugate gradient, then scale the step so that the quadratic KL
        # approximation (1/2) x^T H x equals delta
        x = cg(Hx, g)
        alpha = torch.sqrt(2 * delta / (torch.dot(x, Hx(x)) + EPS))
        old_params = parameters_to_vector(actor_critic.policy.parameters())

        def set_and_eval(step):
            # Apply a candidate step along the natural gradient direction,
            # then report the resulting sample KL and surrogate loss
            vector_to_parameters(old_params - alpha * x * step,
                                 actor_critic.policy.parameters())
            _, logp, _, _, d_kl = actor_critic.policy(obs, act, **policy_args)
            ratio = (logp - logp_old).exp()
            pi_loss = -(ratio * adv).mean()
            return mpi_avg(d_kl.item()), mpi_avg(pi_loss.item())

        if algo == "npg":
            kl, pi_l_new = set_and_eval(step=1.0)

        elif algo == "trpo":
            for j in range(backtrack_iters):
                kl, pi_l_new = set_and_eval(step=backtrack_coeff**j)
                if kl <= delta and pi_l_new <= pi_l_old:
                    logger.log(
                        "Accepting new params at step %d of line search." % j)
                    logger.store(BacktrackIters=j)
                    break

                if j == backtrack_iters - 1:
                    logger.log("Line search failed! Keeping old params.")
                    logger.store(BacktrackIters=j)
                    kl, pi_l_new = set_and_eval(step=0.0)

        # Value function updates
        for _ in range(train_v_iters):
            v = actor_critic.value_function(obs)
            v_loss = F.mse_loss(v, ret)

            # Value function gradient step
            train_vf.zero_grad()
            v_loss.backward()
            average_gradients(train_vf.param_groups)
            train_vf.step()

        v = actor_critic.value_function(obs)
        v_l_new = F.mse_loss(v, ret)

        # Log changes from update
        logger.store(
            LossPi=pi_l_old,
            LossV=v_l_old,
            KL=kl,
            DeltaLossPi=(pi_l_new - pi_l_old),
            DeltaLossV=(v_l_new - v_l_old),
        )
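
The conjugate gradient routine `cg(Hx, g)` used in the TRPO/NPG update above is not shown in this excerpt. A minimal sketch, assuming the standard CG recursion with a fixed iteration budget (the name `cg_iters` and its default of 10 are illustrative assumptions):

    import torch

    EPS = 1e-8

    def cg(Ax, b, cg_iters=10):
        """Approximately solve A x = b, where Ax is a function computing A @ x."""
        x = torch.zeros_like(b)
        r = b.clone()                      # residual b - A @ x, with x = 0 initially
        p = r.clone()                      # current search direction
        r_dot_old = torch.dot(r, r)
        for _ in range(cg_iters):
            z = Ax(p)
            alpha = r_dot_old / (torch.dot(p, z) + EPS)
            x += alpha * p                 # move along the search direction
            r -= alpha * z                 # update the residual
            r_dot_new = torch.dot(r, r)
            p = r + (r_dot_new / r_dot_old) * p
            r_dot_old = r_dot_new
        return x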
    def update():
        temp_get = buf.get()
        obs, act, adv, ret, logp_old = [
            torch.Tensor(x).to(device) for x in temp_get
        ]

        # Pre-update policy loss and entropy (for logging)
        _, logp, _ = actor_critic.policy(obs, act)
        ratio = (logp - logp_old).exp()
        min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                              (1 - clip_ratio) * adv)

        pi_l_old = -(torch.min(ratio * adv, min_adv)).mean()
        ent = (-logp).mean()  # a sample estimate for entropy

        for i in range(train_pi_iters):
            # Output from policy function graph
            _, logp, _ = actor_critic.policy(obs, act)
            # PPO policy objective
            ratio = (logp - logp_old).exp()
            min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                                  (1 - clip_ratio) * adv)

            pi_loss = -(torch.min(ratio * adv, min_adv)).mean()

            # Policy gradient step
            train_pi.zero_grad()
            pi_loss.backward()
            average_gradients(train_pi.param_groups)
            train_pi.step()

            # Sample estimate of the KL between old and new policies, for early stopping
            _, logp, _ = actor_critic.policy(obs, act)
            kl = (logp_old - logp).mean()
            kl = mpi_avg(kl.item())
            if kl > 1.5 * target_kl:
                logger.log(
                    'Early stopping at step %d due to reaching max kl.' % i)
                break
        logger.store(StopIter=i)

        # Training value function
        v = actor_critic.value_function(obs)
        v_l_old = F.mse_loss(v, ret)
        for _ in range(train_v_iters):
            # Output from value function graph
            v = actor_critic.value_function(obs)
            # PPO value function objective
            v_loss = F.mse_loss(v, ret)

            # Value function gradient step
            train_v.zero_grad()
            v_loss.backward()
            average_gradients(train_v.param_groups)
            train_v.step()

        # Log changes from update
        _, logp, _, v = actor_critic(obs, act)
        ratio = (logp - logp_old).exp()
        min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv,
                              (1 - clip_ratio) * adv)
        pi_l_new = -(torch.min(ratio * adv, min_adv)).mean()
        v_l_new = F.mse_loss(v, ret)
        kl = (logp_old - logp).mean()  # a sample estimate for KL-divergence
        clipped = (ratio > (1 + clip_ratio)) | (ratio < (1 - clip_ratio))
        cf = (clipped.float()).mean()
        logger.store(LossPi=pi_l_old,
                     LossV=v_l_old,
                     KL=kl,
                     Entropy=ent,
                     ClipFrac=cf,
                     DeltaLossPi=(pi_l_new - pi_l_old),
                     DeltaLossV=(v_l_new - v_l_old))
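
The `min_adv` construction in the PPO example is an equivalent way of writing the usual clipped surrogate objective built with torch.clamp. A small standalone check of that equivalence, using made-up ratios and advantages purely for illustration:

    import torch

    clip_ratio = 0.2
    ratio = torch.tensor([0.5, 0.9, 1.0, 1.3, 2.0])
    adv = torch.tensor([1.0, -1.0, 0.5, 2.0, -0.5])

    # Formulation used above: pick the clip boundary based on the sign of the advantage
    min_adv = torch.where(adv > 0, (1 + clip_ratio) * adv, (1 - clip_ratio) * adv)
    loss_a = -(torch.min(ratio * adv, min_adv)).mean()

    # Textbook PPO-Clip formulation: clamp the ratio, then take the elementwise min
    clipped_adv = torch.clamp(ratio, 1 - clip_ratio, 1 + clip_ratio) * adv
    loss_b = -(torch.min(ratio * adv, clipped_adv)).mean()

    assert torch.allclose(loss_a, loss_b)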