    def _optimize(self, obs, acts, advs, est_rs):
        """TRPO update: take a natural-gradient policy step (conjugate
        gradient + line search under a KL constraint), then fit the critic
        with an MSE loss."""

        self.obs, self.acts, self.advs, self.est_rs = obs, acts, advs, est_rs

        self.obs = Tensor(self.obs)
        self.acts = Tensor(self.acts)
        self.advs = Tensor(self.advs).unsqueeze(1)
        self.est_rs = Tensor(self.est_rs).unsqueeze(1)

        # Normalize the advantage estimates
        self.advs = (self.advs - self.advs.mean()) / (self.advs.std() + 1e-8)

        # Surrogate loss with entropy bonus

        if self.continuous:
            mean, std, values = self.model(self.obs)

            dis = Normal(mean, std)

            log_prob = dis.log_prob(self.acts).sum(-1, keepdim=True)

            ent = dis.entropy().sum(-1, keepdim=True)

            probs_new = torch.exp(log_prob)
            # Detach the "old" probabilities: the ratio below equals 1 at the
            # current parameters, but its gradient is the usual policy gradient
            probs_old = probs_new.detach() + 1e-8

        else:

            probs, values = self.model(self.obs)

            dis = F.softmax(probs, dim=1)

            self.acts = self.acts.long()

            probs_new = dis.gather(1, self.acts)
            # Detach here as well, mirroring the continuous branch; without
            # .detach() the gradient of the ratio (and hence the policy
            # gradient) effectively vanishes
            probs_old = probs_new.detach() + 1e-8

            ent = -(torch.log(dis + 1e-8) * dis).sum(-1, keepdim=True)

        ratio = probs_new / probs_old

        surrogate_loss = -torch.mean(
            ratio * self.advs) - self.entropy_para * ent.mean()

        # criterion = torch.nn.MSELoss()
        # empty_value_loss = criterion( values, values.detach() )

        # Calculate the gradient of the surrogate loss
        self.model.zero_grad()
        surrogate_loss.backward()
        policy_gradient = parameters_to_vector([
            p.grad for p in self.model.policy_parameters()
        ]).squeeze(0).detach()

        # Skip the natural-gradient update if the policy gradient is all zeros
        if policy_gradient.nonzero().size(0):
            # Use Conjugate gradient to calculate step direction
            step_direction = self.conjugate_gradient(-policy_gradient)
            # line search for step
            shs = .5 * step_direction.dot(
                self.hessian_vector_product(step_direction))

            lm = torch.sqrt(shs / self.max_kl)
            fullstep = step_direction / lm

            gdotstepdir = -policy_gradient.dot(step_direction)
            theta = self.linesearch(
                parameters_to_vector(self.model.policy_parameters()).detach(),
                fullstep, gdotstepdir / lm)
            # Update parameters of policy model
            # Snapshot the current policy to measure the KL divergence of the
            # updated policy against it
            old_model = copy.deepcopy(self.model)

            if torch.isnan(theta).any():
                print("NaN detected. Skipping update...")
            else:
                # for param in self.model.policy_parameters():
                #     print(param)
                vector_to_parameters(theta, self.model.policy_parameters())

            kl_old_new = self.mean_kl_divergence(old_model)
            print('KL:{:10} , Entropy:{:10}'.format(kl_old_new.item(),
                                                    ent.mean().item()))

        else:
            print("Policy gradient is 0. Skipping update...")
            print(policy_gradient.shape)

        self.model.zero_grad()

        if self.continuous:
            _, _, values = self.model(self.obs)
        else:
            _, values = self.model(self.obs)

        criterion = torch.nn.MSELoss()
        critic_loss = self.value_loss_coeff * criterion(values, self.est_rs)
        critic_loss.backward()
        self.optim.step()
        print("MSELoss for Value Net:{}".format(critic_loss.item()))
Example 2
    def _optimize(self, obs, acts, advs, est_rs):
        """PPO update: clipped-surrogate policy loss with an entropy bonus,
        plus an MSE critic loss, optimized jointly in a single step."""

        self.optim.zero_grad()
        
        obs  = Tensor( obs )
        acts = Tensor( acts )
        advs = Tensor( advs ).unsqueeze(1)
        est_rs = Tensor( est_rs ).unsqueeze(1)

        if self.continuous:
            mean, std, values = self.model( obs )
            with torch.no_grad():
                mean_old, std_old, _ = self.model_old( obs )

            dis = Normal(mean, std)
            dis_old = Normal(mean_old, std_old)
            
            log_prob     = dis.log_prob( acts ).sum( -1, keepdim=True )
            log_prob_old = dis_old.log_prob( acts ).sum( -1, keepdim=True )

            ent = dis.entropy().sum( -1, keepdim=True )

            ratio = torch.exp(log_prob - log_prob_old)            

        else:

            probs, values = self.model(obs)
            with torch.no_grad():
                probs_old, _ = self.model_old(obs)

            dis = F.softmax( probs, dim = 1 )
            dis_old = F.softmax(probs_old, dim = 1 )

            acts = acts.long()

            probs = dis.gather( 1, acts )
            probs_old = dis_old.gather( 1, acts )

            # dis = Categorical( probs )
            # dis_old = Categorical( probs_old )
            ratio = probs / ( probs_old + 1e-8 )            

            # log_prob     = dis.log_prob( acts ).unsqueeze(1)
            # log_prob_old = dis_old.log_prob( acts ).unsqueeze(1)

            ent = -( torch.log( dis + 1e-8 ) * dis ).sum( -1, keepdim=True )


        # Normalize the advantage
        advs = (advs - advs.mean()) / (advs.std() + 1e-8)
        surrogate_loss_pre_clip = ratio * advs
        surrogate_loss_clip = torch.clamp(ratio, 
                        1.0 - self.clip_para,
                        1.0 + self.clip_para) * advs

        # print("ratio min:{} max:{}".format(ratio.detach().min().item(), ratio.detach().max().item()))

        surrogate_loss = -torch.mean(torch.min(surrogate_loss_clip, surrogate_loss_pre_clip))

        policy_loss = surrogate_loss - self.entropy_para * ent.mean()

        criterion = nn.MSELoss()
        critic_loss = criterion( values, est_rs )
        # print("Critic Loss:{}".format(critic_loss.item()))

        self.writer.add_scalar( "Training/Critic_Loss", critic_loss.item(), self.step_count )
        loss = policy_loss + self.value_loss_coeff * critic_loss

        loss.backward()

        self.optim.step()

        self.step_count += 1
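The PPO variant above reads self.model_old but never refreshes it, so the surrounding training loop is expected to synchronize the old policy with the current one once per collected batch. Below is a minimal sketch of such a driver, assuming a hypothetical update() method and minibatch iterator; it is not this repository's code.

    def update(self, minibatches, epochs=10):
        """Hypothetical driver loop around _optimize."""
        # Freeze a copy of the current policy to act as the "old" policy
        # for the probability ratios computed in _optimize.
        self.model_old.load_state_dict(self.model.state_dict())

        for _ in range(epochs):
            for obs, acts, advs, est_rs in minibatches:
                self._optimize(obs, acts, advs, est_rs)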