Example #1
 def calc_gae_advs_v_targets(self, batch, v_preds):
     '''
     Calculate GAE, and advs = GAE, v_targets = advs + v_preds
     See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
     '''
     next_states = batch['next_states'][-1]
     if not self.body.env.is_venv:
         next_states = next_states.unsqueeze(dim=0)
     with torch.no_grad():
         next_v_pred = self.calc_v(next_states, use_cache=False)
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
         next_v_pred = next_v_pred.unsqueeze(dim=0)
     v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
     advs = math_util.calc_gaes(batch['rewards'], batch['dones'],
                                v_preds_all, self.gamma, self.lam)
     v_targets = advs + v_preds
     advs = math_util.standardize(
         advs)  # standardize only for advs, not v_targets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets

 def calc_ret_advs_v_targets(self, batch, v_preds):
     '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     rets = math_util.calc_returns(batch['rewards'], batch['dones'],
                                   self.gamma)
     advs = rets - v_preds
     v_targets = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
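
In calc_gae_advs_v_targets, math_util.calc_gaes is given a value tensor with one extra bootstrap entry (next_v_pred concatenated after v_preds), so every timestep t can see V(s_{t+1}). A minimal sketch of such a GAE computation, following the recursion from Schulman et al. (delta_t = r_t + gamma * (1 - done_t) * V(s_{t+1}) - V(s_t), then A_t = delta_t + gamma * lam * (1 - done_t) * A_{t+1}); the actual math_util.calc_gaes may differ in details:

import torch

def calc_gaes(rewards, dones, v_preds, gamma, lam):
    '''GAE sketch: rewards/dones have T entries along dim 0, v_preds has T + 1 (bootstrap appended).'''
    T = rewards.shape[0]
    assert v_preds.shape[0] == T + 1
    gaes = torch.zeros_like(rewards)
    future_gae = torch.zeros_like(rewards[0])
    not_dones = 1 - dones  # zero out bootstrapping across episode boundaries
    for t in reversed(range(T)):
        delta = rewards[t] + gamma * not_dones[t] * v_preds[t + 1] - v_preds[t]
        gaes[t] = future_gae = delta + gamma * lam * not_dones[t] * future_gae
    return gaes

Note that v_targets = advs + v_preds is formed before the advantages are standardized, so the critic regresses onto unscaled targets while only the policy gradient sees the standardized advantages.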
Example #3
 def calc_q_loss(self, batch):
     '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
     states = batch['states']
     next_states = batch['next_states']
     if self.body.env.is_venv:
         states = math_util.venv_unpack(states)
         next_states = math_util.venv_unpack(next_states)
     q_preds = self.net(states)
     with torch.no_grad():
         next_q_preds = self.net(next_states)
     if self.body.env.is_venv:
         q_preds = math_util.venv_pack(q_preds, self.body.env.num_envs)
         next_q_preds = math_util.venv_pack(next_q_preds,
                                            self.body.env.num_envs)
     act_q_preds = q_preds.gather(
         -1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
     act_next_q_preds = next_q_preds.gather(
         -1, batch['next_actions'].long().unsqueeze(-1)).squeeze(-1)
     act_q_targets = batch['rewards'] + self.gamma * (
         1 - batch['dones']) * act_next_q_preds
     logger.debug(
         f'act_q_preds: {act_q_preds}\nact_q_targets: {act_q_targets}')
     q_loss = self.net.loss_fn(act_q_preds, act_q_targets)
     return q_loss
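
calc_q_loss gathers the next-state Q value at batch['next_actions'] (the action actually taken) rather than taking a max over actions, so the target is the on-policy SARSA target q_target = r + gamma * (1 - done) * Q(s', a'). A small self-contained illustration of that target construction with dummy tensors, using MSE in place of self.net.loss_fn:

import torch
import torch.nn as nn

gamma = 0.99
q_preds = torch.randn(5, 3)        # Q(s, .) for 5 transitions and 3 discrete actions
next_q_preds = torch.randn(5, 3)   # Q(s', .) from the same net, computed without grad
actions = torch.randint(0, 3, (5,))
next_actions = torch.randint(0, 3, (5,))
rewards = torch.randn(5)
dones = torch.zeros(5)

act_q_preds = q_preds.gather(-1, actions.long().unsqueeze(-1)).squeeze(-1)
act_next_q_preds = next_q_preds.gather(-1, next_actions.long().unsqueeze(-1)).squeeze(-1)
act_q_targets = rewards + gamma * (1 - dones) * act_next_q_preds  # SARSA target
q_loss = nn.MSELoss()(act_q_preds, act_q_targets)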
Example #4
 def calc_ret_advs_v_targets(self, batch, v_preds):
     """
     Args:
         batch:
         v_preds:
     Returns: advs(difference between prediction and original reward), v_targets, original reward.
     """
     '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     rets = math_util.calc_returns(batch['rewards'], batch['dones'],
                                   self.gamma)
     advs = rets - v_preds
     v_targets = rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
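
calc_ret_advs_v_targets delegates to math_util.calc_returns for plain Monte Carlo returns. A minimal sketch of that computation, assuming the backward recursion G_t = r_t + gamma * (1 - done_t) * G_{t+1} with the return past the last step taken as zero (the real math_util implementation may differ):

import torch

def calc_returns(rewards, dones, gamma):
    '''Plain discounted returns: G_t = r_t + gamma * (1 - done_t) * G_{t+1}, with G = 0 past the batch.'''
    rets = torch.zeros_like(rewards)
    future_ret = torch.zeros_like(rewards[0])
    not_dones = 1 - dones
    for t in reversed(range(rewards.shape[0])):
        rets[t] = future_ret = rewards[t] + gamma * not_dones[t] * future_ret
    return rets

With v_targets = rets, the critic regresses directly onto these returns, and advs = rets - v_preds is the usual baseline-subtracted advantage.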
Example #5
 def calc_nstep_advs_v_targets(self, batch, v_preds):
     '''
     Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets
     See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf
     '''
     next_states = batch['next_states'][-1]
     if not self.body.env.is_venv:
         next_states = next_states.unsqueeze(dim=0)
     with torch.no_grad():
         next_v_pred = self.calc_v(next_states, use_cache=False)
     v_preds = v_preds.detach()  # adv does not accumulate grad
     if self.body.env.is_venv:
         v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
     nstep_rets = math_util.calc_nstep_returns(batch['rewards'],
                                               batch['dones'], next_v_pred,
                                               self.gamma,
                                               self.num_step_returns)
     advs = nstep_rets - v_preds
     v_targets = nstep_rets
     if self.body.env.is_venv:
         advs = math_util.venv_unpack(advs)
         v_targets = math_util.venv_unpack(v_targets)
     logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
     return advs, v_targets
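
math_util.calc_nstep_returns is called with the bootstrap value of the final next state and n = self.num_step_returns. A sketch of the n-step return R_t = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * V(s_{t+n}), under the assumption (implied by passing a single next_v_pred) that the sampled batch is exactly the n-step segment, so the recursion bootstraps from next_v_pred at its end:

import torch

def calc_nstep_returns(rewards, dones, next_v_pred, gamma, n):
    '''n-step return sketch; assumes the batch holds exactly n transitions along dim 0.'''
    assert rewards.shape[0] == n
    rets = torch.zeros_like(rewards)
    future_ret = next_v_pred  # V(s_{t+n}) bootstraps the tail of the segment
    not_dones = 1 - dones
    for t in reversed(range(n)):
        rets[t] = future_ret = rewards[t] + gamma * not_dones[t] * future_ret
    return rets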
Example #6
 def train(self):
     if util.in_eval_lab_modes():
         return np.nan
     clock = self.body.env.clock
     if self.to_train == 1:
         net_util.copy(self.net, self.old_net)  # update old net
         batch = self.sample()
         clock.set_batch_size(len(batch))
         _pdparams, v_preds = self.calc_pdparam_v(batch)
         advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
         # piggy back on batch, but remember to not pack or unpack
         batch['advs'], batch['v_targets'] = advs, v_targets
         if self.body.env.is_venv:  # unpack if venv for minibatch sampling
             for k, v in batch.items():
                 if k not in ('advs', 'v_targets'):
                     batch[k] = math_util.venv_unpack(v)
         total_loss = torch.tensor(0.0)
         for _ in range(self.training_epoch):
             minibatches = util.split_minibatch(batch, self.minibatch_size)
             for minibatch in minibatches:
                 if self.body.env.is_venv:  # re-pack to restore proper shape
                     for k, v in minibatch.items():
                         if k not in ('advs', 'v_targets'):
                             minibatch[k] = math_util.venv_pack(
                                 v, self.body.env.num_envs)
                 advs, v_targets = minibatch['advs'], minibatch['v_targets']
                 pdparams, v_preds = self.calc_pdparam_v(minibatch)
                 policy_loss = self.calc_policy_loss(
                     minibatch, pdparams, advs)  # from actor
                 val_loss = self.calc_val_loss(v_preds,
                                               v_targets)  # from critic
                 if self.shared:  # shared network
                     loss = policy_loss + val_loss
                     self.net.train_step(loss,
                                         self.optim,
                                         self.lr_scheduler,
                                         clock=clock,
                                         global_net=self.global_net)
                 else:
                     self.net.train_step(policy_loss,
                                         self.optim,
                                         self.lr_scheduler,
                                         clock=clock,
                                         global_net=self.global_net)
                     self.critic_net.train_step(
                         val_loss,
                         self.critic_optim,
                         self.critic_lr_scheduler,
                         clock=clock,
                         global_net=self.global_critic_net)
                     loss = policy_loss + val_loss
                 total_loss += loss
         loss = total_loss / self.training_epoch / len(minibatches)
         # reset
         self.to_train = 0
         logger.debug(
             f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
         )
         return loss.item()
     else:
         return np.nan
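
The train loop flattens venv-shaped tensors for minibatch sampling and restores the (time, env) layout before each forward pass, while 'advs' and 'v_targets' are already flat and are excluded from the pack/unpack passes. A sketch of what venv_unpack / venv_pack amount to, assuming they are plain reshapes over the leading time and env dimensions (the actual math_util helpers may handle shapes differently):

import torch

def venv_unpack(data):
    '''Flatten (T, num_envs, *shape) venv data to (T * num_envs, *shape) for minibatching.'''
    return data.reshape(-1, *data.shape[2:])

def venv_pack(data, num_envs):
    '''Restore flat (T * num_envs, *shape) data back to (T, num_envs, *shape).'''
    return data.reshape(-1, num_envs, *data.shape[1:])

# e.g. 8 timesteps from 4 envs with 3-dim states
states = torch.zeros(8, 4, 3)
flat = venv_unpack(states)                   # shape (32, 3)
assert venv_pack(flat, 4).shape == states.shape

The reported loss is total_loss / training_epoch / len(minibatches), i.e. the mean per-minibatch loss across all epochs.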
Example #7
    def train(self):
        # torch.save(self.net.state_dict(), './reward_model/policy_pretrain.mdl')
        # raise ValueError("policy pretrain stops")
        if util.in_eval_lab_modes():
            return np.nan
        clock = self.body.env.clock
        if self.body.env.clock.epi > 700:
            self.pretrain_finished = True
            # torch.save(self.discriminator.state_dict(), './reward_model/airl_pretrain.mdl')
            # raise ValueError("pretrain stops here")
        if self.to_train == 1:
            net_util.copy(self.net, self.old_net)  # update old net
            batch = self.sample()
            if self.reward_type == 'OFFGAN':
                batch = self.replace_reward_batch(batch)
            # if self.reward_type == 'DISC':
            #     batch = self.fetch_disc_reward(batch)
            # if self.reward_type == 'AIRL':
            #     batch = self.fetch_airl_reward(batch)
            # if self.reward_type == 'OFFGAN_update':
            #     batch = self.fetch_offgan_reward(batch)

            clock.set_batch_size(len(batch))
            _pdparams, v_preds = self.calc_pdparam_v(batch)
            advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
            # piggy back on batch, but remember to not pack or unpack
            batch['advs'], batch['v_targets'] = advs, v_targets
            if self.body.env.is_venv:  # unpack if venv for minibatch sampling
                for k, v in batch.items():
                    if k not in ('advs', 'v_targets'):
                        batch[k] = math_util.venv_unpack(v)
            total_loss = torch.tensor(0.0)
            for _ in range(self.training_epoch):
                minibatches = util.split_minibatch(batch, self.minibatch_size)

                # if not self.pretrain_finished or not self.policy_training_flag:
                #     break

                for minibatch in minibatches:
                    if self.body.env.is_venv:  # re-pack to restore proper shape
                        for k, v in minibatch.items():
                            if k not in ('advs', 'v_targets'):
                                minibatch[k] = math_util.venv_pack(
                                    v, self.body.env.num_envs)
                    advs, v_targets = minibatch['advs'], minibatch['v_targets']
                    pdparams, v_preds = self.calc_pdparam_v(minibatch)
                    policy_loss = self.calc_policy_loss(
                        minibatch, pdparams, advs)  # from actor
                    val_loss = self.calc_val_loss(v_preds,
                                                  v_targets)  # from critic
                    if self.shared:  # shared network
                        loss = policy_loss + val_loss
                        self.net.train_step(loss,
                                            self.optim,
                                            self.lr_scheduler,
                                            clock=clock,
                                            global_net=self.global_net)
                    else:
                        # pretrain_finished == False -> keep the policy fixed; update only the value net and discriminator
                        if not self.pretrain_finished:
                            self.critic_net.train_step(
                                val_loss,
                                self.critic_optim,
                                self.critic_lr_scheduler,
                                clock=clock,
                                global_net=self.global_critic_net)
                            loss = val_loss
                        if self.pretrain_finished and self.policy_training_flag:
                            self.net.train_step(policy_loss,
                                                self.optim,
                                                self.lr_scheduler,
                                                clock=clock,
                                                global_net=self.global_net)
                            self.critic_net.train_step(
                                val_loss,
                                self.critic_optim,
                                self.critic_lr_scheduler,
                                clock=clock,
                                global_net=self.global_critic_net)
                            loss = policy_loss + val_loss
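                        # when pretrain_finished is True but policy_training_flag is False,
                        # neither branch above runs and loss is not recomputed for this minibatch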

                    total_loss += loss
            loss = total_loss / self.training_epoch / len(minibatches)
            if not self.pretrain_finished:
                logger.info(
                    f'Warming up value net, epi: {clock.epi}, frame: {clock.frame}, loss: {loss}')
            # reset
            self.to_train = 0
            self.policy_training_flag = False
            logger.debug(
                f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}'
            )
            return loss.item()
        else:
            return np.nan