Example #1
0
    def calc_gradients(self, input_dict, opt_step):
        """Run one PPO minibatch update: forward, losses, backward, optional step.

        Evaluates the model on the minibatch, combines the actor, critic,
        entropy and bound losses into one scalar, backpropagates it and, when
        ``opt_step`` is true, applies ``self.optimizer.step()``.

        Args:
            input_dict: minibatch dict. Keys read here: 'old_values',
                'old_logp_actions', 'advantages', 'mu', 'sigma', 'returns',
                'actions', 'obs', plus 'rnn_masks' and 'rnn_states' when
                ``self.is_rnn`` is set.
            opt_step: if true, step the optimizer after the backward pass;
                otherwise gradients are computed but not applied here.

        Side effects:
            Sets ``self.train_result`` to ``(a_loss, c_loss, entropy, kl,
            last_lr, lr_mul, mu, sigma, b_loss)``. The loss and kl entries are
            Python floats (via ``.item()``); ``mu``/``sigma`` are detached.
        """
        self.set_train()  # presumably switches the model to train mode
        value_preds_batch = input_dict['old_values']
        old_action_log_probs_batch = input_dict['old_logp_actions']
        advantage = input_dict['advantages']
        old_mu_batch = input_dict['mu']
        old_sigma_batch = input_dict['sigma']
        return_batch = input_dict['returns']
        actions_batch = input_dict['actions']
        obs_batch = self._preproc_obs(input_dict['obs'])

        # lr_mul is fixed at 1.0, so the clip range equals self.e_clip; the
        # multiplier is still reported in train_result.
        lr_mul = 1.0
        curr_e_clip = lr_mul * self.e_clip

        batch_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
            'obs': obs_batch,
        }

        rnn_masks = None
        if self.is_rnn:
            rnn_masks = input_dict['rnn_masks']
            batch_dict['rnn_states'] = input_dict['rnn_states']
            batch_dict['seq_length'] = self.seq_len

        res_dict = self.model(batch_dict)
        action_log_probs = res_dict['prev_neglogp']
        values = res_dict['value']
        entropy = res_dict['entropy']
        mu = res_dict['mu']
        sigma = res_dict['sigma']

        a_loss = common_losses.actor_loss(old_action_log_probs_batch,
                                          action_log_probs, advantage,
                                          self.ppo, curr_e_clip)

        # The critic loss is zeroed only when a central value net is in use
        # and the experimental override is off; otherwise train it here.
        if not self.use_experimental_cv and self.has_central_value:
            c_loss = torch.zeros(1, device=self.ppo_device)
        else:
            c_loss = common_losses.critic_loss(value_preds_batch, values,
                                               curr_e_clip, return_batch,
                                               self.clip_value)

        b_loss = self.bound_loss(mu)
        losses, sum_mask = torch_ext.apply_masks([
            a_loss.unsqueeze(1), c_loss,
            entropy.unsqueeze(1),
            b_loss.unsqueeze(1)
        ], rnn_masks)
        a_loss, c_loss, entropy, b_loss = losses[0], losses[1], losses[2], losses[3]

        loss = (a_loss
                + 0.5 * c_loss * self.critic_coef
                - entropy * self.entropy_coef
                + b_loss * self.bounds_loss_coef)

        # Clear gradients by dropping them (None) instead of zero-filling.
        # NOTE(review): this runs on every call, so opt_step=False does not
        # accumulate gradients across calls - confirm callers don't rely on it.
        for param in self.model.parameters():
            param.grad = None
        loss.backward()

        if self.config['truncate_grads']:
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
        if opt_step:
            self.optimizer.step()

        with torch.no_grad():
            # RNN batches keep the unreduced KL so it can be weighted by
            # rnn_masks before averaging over sum_mask.
            reduce_kl = not self.is_rnn
            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(),
                                          old_mu_batch, old_sigma_batch,
                                          reduce_kl)
            if self.is_rnn:
                kl_dist = (kl_dist * rnn_masks).sum() / sum_mask
            kl_dist = kl_dist.item()

        self.train_result = (a_loss.item(), c_loss.item(), entropy.item(),
                             kl_dist, self.last_lr, lr_mul,
                             mu.detach(), sigma.detach(), b_loss.item())
Example #2
0
    def calc_gradients(self, input_dict):
        """Run one PPO minibatch update with optional mixed precision.

        Evaluates the model inside
        ``torch.cuda.amp.autocast(enabled=self.mixed_precision)`` and combines
        the actor, critic and entropy losses. It then backpropagates through
        ``self.scaler`` and steps the optimizer, clipping the gradient norm
        first when ``self.config['truncate_grads']`` is set.

        Args:
            input_dict: minibatch dict. Keys read here: 'old_values',
                'old_logp_actions', 'advantages', 'returns', 'actions', 'obs';
                'action_masks' when ``self.use_action_masks``; 'rnn_masks' and
                'rnn_states' when ``self.is_rnn``.

        Side effects:
            Updates model parameters and sets ``self.train_result`` to
            ``(a_loss, c_loss, entropy, kl_dist, last_lr, lr_mul)``. The loss
            and KL entries are kept as returned (no ``.item()`` conversion).
        """
        value_preds_batch = input_dict['old_values']
        old_action_log_probs_batch = input_dict['old_logp_actions']
        advantage = input_dict['advantages']
        return_batch = input_dict['returns']
        actions_batch = input_dict['actions']
        obs_batch = self._preproc_obs(input_dict['obs'])

        # lr_mul is fixed at 1.0; it is still reported in train_result.
        lr_mul = 1.0
        curr_e_clip = lr_mul * self.e_clip

        batch_dict = {
            'is_train': True,
            'prev_actions': actions_batch,
            'obs': obs_batch,
        }
        if self.use_action_masks:
            batch_dict['action_masks'] = input_dict['action_masks']
        rnn_masks = None
        if self.is_rnn:
            rnn_masks = input_dict['rnn_masks']
            batch_dict['rnn_states'] = input_dict['rnn_states']
            batch_dict['seq_length'] = self.seq_len

        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            res_dict = self.model(batch_dict)
            action_log_probs = res_dict['prev_neglogp']
            values = res_dict['value']
            entropy = res_dict['entropy']
            a_loss = common_losses.actor_loss(old_action_log_probs_batch,
                                              action_log_probs, advantage,
                                              self.ppo, curr_e_clip)

            # The critic loss is zeroed only when a central value net is in
            # use and the experimental override is off.
            if not self.use_experimental_cv and self.has_central_value:
                c_loss = torch.zeros(1, device=self.ppo_device)
            else:
                c_loss = common_losses.critic_loss(value_preds_batch, values,
                                                   curr_e_clip, return_batch,
                                                   self.clip_value)

            losses, _ = torch_ext.apply_masks(
                [a_loss.unsqueeze(1), c_loss,
                 entropy.unsqueeze(1)], rnn_masks)
            a_loss, c_loss, entropy = losses[0], losses[1], losses[2]
            loss = a_loss + 0.5 * c_loss * self.critic_coef - entropy * self.entropy_coef

            if self.multi_gpu:
                self.optimizer.zero_grad()
            else:
                for param in self.model.parameters():
                    param.grad = None

        self.scaler.scale(loss).backward()

        # When clipping, gradients are unscaled first so the norm threshold
        # applies to the real gradients. On multi-GPU, gradients are explicitly
        # synchronized before clipping, so the step must skip re-synchronizing.
        truncate = self.config['truncate_grads']
        clip_synchronized = truncate and self.multi_gpu
        if truncate:
            if clip_synchronized:
                self.optimizer.synchronize()
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)
        if clip_synchronized:
            with self.optimizer.skip_synchronize():
                self.scaler.step(self.optimizer)
                self.scaler.update()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

        with torch.no_grad():
            # Approximate KL from the change in negative log-probabilities.
            kl_dist = 0.5 * (
                (old_action_log_probs_batch - action_log_probs)**2)
            if self.is_rnn:
                # NOTE(review): divides by rnn_masks.numel() rather than the
                # mask sum, so masked-out steps still count in the denominator
                # - confirm this is intended.
                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()
            else:
                kl_dist = kl_dist.mean()

        self.train_result = (a_loss, c_loss, entropy, kl_dist, self.last_lr,
                             lr_mul)
Example #3
0
    def calc_gradients(self, input_dict):
        """Perform one mixed-precision PPO update for a continuous-action policy.

        Evaluates the current policy on the minibatch and builds the actor,
        critic, entropy and action-bound losses. It backpropagates through the
        AMP grad scaler and steps the optimizer, optionally clipping the
        gradient norm first. When ``self.ewma_ppo`` is set, the actor loss is
        decoupled against the EWMA proxy model, which is then updated.
        Minibatch statistics go to ``self.diagnostics``, and a summary tuple is
        stored in ``self.train_result``.

        Args:
            input_dict: minibatch dict. Keys read here: 'old_values',
                'old_logp_actions', 'advantages', 'mu', 'sigma', 'returns',
                'actions', 'obs', and 'rnn_masks'/'rnn_states' for RNN models.
        """
        old_values = input_dict['old_values']
        old_neglogp = input_dict['old_logp_actions']
        advantages = input_dict['advantages']
        old_mu = input_dict['mu']
        old_sigma = input_dict['sigma']
        returns = input_dict['returns']
        actions = input_dict['actions']
        obs = self._preproc_obs(input_dict['obs'])

        lr_mul = 1.0
        e_clip = lr_mul * self.e_clip

        batch_dict = {'is_train': True, 'prev_actions': actions, 'obs': obs}

        rnn_masks = input_dict['rnn_masks'] if self.is_rnn else None
        if self.is_rnn:
            batch_dict.update(
                rnn_states=input_dict['rnn_states'],
                seq_length=self.seq_len,
            )

        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            out = self.model(batch_dict)
            new_neglogp = out['prev_neglogp']
            values = out['values']
            entropy = out['entropy']
            mu, sigma = out['mus'], out['sigmas']

            if not self.ewma_ppo:
                a_loss = common_losses.actor_loss(
                    old_neglogp, new_neglogp, advantages, self.ppo, e_clip)
            else:
                ewma_out = self.ewma_model(batch_dict)
                proxy_neglogp = ewma_out['prev_neglogp']
                a_loss = common_losses.decoupled_actor_loss(
                    old_neglogp, new_neglogp, proxy_neglogp, advantages, e_clip)
                # From here on, KL and diagnostics compare against the proxy policy.
                old_neglogp = proxy_neglogp
                old_mu, old_sigma = ewma_out['mus'], ewma_out['sigmas']

            c_loss = (common_losses.critic_loss(old_values, values, e_clip,
                                                returns, self.clip_value)
                      if self.has_value_loss
                      else torch.zeros(1, device=self.ppo_device))

            bound_type = self.bound_loss_type
            if bound_type == 'bound':
                b_loss = self.bound_loss(mu)
            elif bound_type == 'regularisation':
                b_loss = self.reg_loss(mu)
            else:
                b_loss = torch.zeros(1, device=self.ppo_device)

            masked, _ = torch_ext.apply_masks(
                [a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1), b_loss.unsqueeze(1)],
                rnn_masks)
            a_loss, c_loss, entropy, b_loss = masked[:4]

            loss = (a_loss
                    + 0.5 * c_loss * self.critic_coef
                    - entropy * self.entropy_coef
                    + b_loss * self.bounds_loss_coef)

            if not self.multi_gpu:
                for p in self.model.parameters():
                    p.grad = None
            else:
                self.optimizer.zero_grad()

        self.scaler.scale(loss).backward()

        # Clipping needs unscaled gradients. On multi-GPU they are explicitly
        # synchronized first, so the following step skips synchronization.
        synced_for_clip = False
        if self.truncate_grads:
            if self.multi_gpu:
                self.optimizer.synchronize()
                synced_for_clip = True
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

        if synced_for_clip:
            with self.optimizer.skip_synchronize():
                self.scaler.step(self.optimizer)
                self.scaler.update()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

        with torch.no_grad():
            # With RNN masks keep the per-element KL and weight it by the masks.
            per_step_kl = rnn_masks is not None
            kl_dist = torch_ext.policy_kl(mu.detach(), sigma.detach(),
                                          old_mu, old_sigma, not per_step_kl)
            if per_step_kl:
                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()

        if self.ewma_ppo:
            self.ewma_model.update()

        stats = {
            'values': old_values,
            'returns': returns,
            'new_neglogp': new_neglogp,
            'old_neglogp': old_neglogp,
            'masks': rnn_masks,
        }
        self.diagnostics.mini_batch(self, stats, e_clip, 0)

        self.train_result = (a_loss, c_loss, entropy, kl_dist, self.last_lr,
                             lr_mul, mu.detach(), sigma.detach(), b_loss)
    def calc_gradients(self, input_dict):
        """Perform one mixed-precision PPO update for a discrete-action policy.

        Evaluates the policy on the minibatch (with optional action masks and
        RNN state) and builds the actor, critic and entropy losses. It
        backpropagates through the AMP grad scaler and steps the optimizer,
        optionally clipping the gradient norm first. Afterwards it optionally:
        - runs the phasic policy gradient value training, which replaces the
          reported critic loss;
        - updates the EWMA proxy model;
        - reports minibatch statistics to ``self.diagnostics``.

        Args:
            input_dict: minibatch dict. Keys read here: 'old_values',
                'old_logp_actions', 'advantages', 'returns', 'actions', 'obs';
                'action_masks' when ``self.use_action_masks``; 'rnn_masks',
                'rnn_states' and 'dones' when ``self.is_rnn``.
        """
        old_values = input_dict['old_values']
        old_neglogp = input_dict['old_logp_actions']
        advantages = input_dict['advantages']
        returns = input_dict['returns']
        actions = input_dict['actions']
        obs = self._preproc_obs(input_dict['obs'])

        lr_mul = 1.0
        e_clip = lr_mul * self.e_clip

        batch_dict = {'is_train': True, 'prev_actions': actions, 'obs': obs}
        if self.use_action_masks:
            batch_dict['action_masks'] = input_dict['action_masks']

        rnn_masks = input_dict['rnn_masks'] if self.is_rnn else None
        if self.is_rnn:
            batch_dict.update(
                rnn_states=input_dict['rnn_states'],
                seq_length=self.seq_len,
                bptt_len=self.bptt_len,
                dones=input_dict['dones'],
            )

        with torch.cuda.amp.autocast(enabled=self.mixed_precision):
            out = self.model(batch_dict)
            new_neglogp = out['prev_neglogp']
            values = out['values']
            entropy = out['entropy']

            if not self.ewma_ppo:
                a_loss = common_losses.actor_loss(
                    old_neglogp, new_neglogp, advantages, self.ppo, e_clip)
            else:
                proxy_neglogp = self.ewma_model(batch_dict)['prev_neglogp']
                a_loss = common_losses.decoupled_actor_loss(
                    old_neglogp, new_neglogp, proxy_neglogp, advantages, e_clip)
                # From here on, KL and diagnostics compare against the proxy policy.
                old_neglogp = proxy_neglogp

            c_loss = (common_losses.critic_loss(old_values, values, e_clip,
                                                returns, self.clip_value)
                      if self.has_value_loss
                      else torch.zeros(1, device=self.ppo_device))

            masked, _ = torch_ext.apply_masks(
                [a_loss.unsqueeze(1), c_loss, entropy.unsqueeze(1)], rnn_masks)
            a_loss, c_loss, entropy = masked[:3]

            loss = (a_loss
                    + 0.5 * c_loss * self.critic_coef
                    - entropy * self.entropy_coef)

            if not self.multi_gpu:
                for p in self.model.parameters():
                    p.grad = None
            else:
                self.optimizer.zero_grad()

        self.scaler.scale(loss).backward()

        # Clipping needs unscaled gradients. On multi-GPU they are explicitly
        # synchronized first, so the following step skips synchronization.
        synced_for_clip = False
        if self.truncate_grads:
            if self.multi_gpu:
                self.optimizer.synchronize()
                synced_for_clip = True
            self.scaler.unscale_(self.optimizer)
            nn.utils.clip_grad_norm_(self.model.parameters(), self.grad_norm)

        if synced_for_clip:
            with self.optimizer.skip_synchronize():
                self.scaler.step(self.optimizer)
                self.scaler.update()
        else:
            self.scaler.step(self.optimizer)
            self.scaler.update()

        with torch.no_grad():
            # Approximate KL from the change in negative log-probabilities.
            kl_dist = 0.5 * ((old_neglogp - new_neglogp) ** 2)
            if rnn_masks is None:
                kl_dist = kl_dist.mean()
            else:
                kl_dist = (kl_dist * rnn_masks).sum() / rnn_masks.numel()

        if self.has_phasic_policy_gradients:
            # The reported critic loss becomes the PPG auxiliary value loss.
            c_loss = self.ppg_aux_loss.train_value(self, input_dict)

        if self.ewma_ppo:
            self.ewma_model.update()

        stats = {
            'values': old_values,
            'returns': returns,
            'new_neglogp': new_neglogp,
            'old_neglogp': old_neglogp,
            'masks': rnn_masks,
        }
        self.diagnostics.mini_batch(self, stats, e_clip, 0)

        self.train_result = (a_loss, c_loss, entropy, kl_dist, self.last_lr,
                             lr_mul)