def calc_policy_loss(self, batch, pdparams, advs):
    '''Calculate the actor's policy loss'''
    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    actions = batch['actions']
    if self.body.env.is_venv:
        actions = math_util.venv_unpack(actions)
    log_probs = action_pd.log_prob(actions)
    policy_loss = -self.policy_loss_coef * (log_probs * advs).mean()
    if self.entropy_coef_spec:
        entropy = action_pd.entropy().mean()
        self.body.mean_entropy = entropy  # update logging variable
        policy_loss += (-self.body.entropy_coef * entropy)
    logger.debug(f'Actor policy loss: {policy_loss:g}')
    return policy_loss
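# --- Illustrative sketch (not part of the library) ---
# A minimal standalone version of the loss computed in calc_policy_loss above:
# loss = -coef * mean(log_prob(a) * adv) - entropy_coef * mean(entropy).
# All names here (toy_policy_loss, logits, advs) are hypothetical, and the coefficients are placeholders.
import torch
from torch.distributions import Categorical

def toy_policy_loss(logits, actions, advs, policy_loss_coef=1.0, entropy_coef=0.01):
    '''Actor loss for a batch of discrete actions with given advantages.'''
    action_pd = Categorical(logits=logits)   # action probability distribution
    log_probs = action_pd.log_prob(actions)  # log pi(a_t | s_t)
    policy_loss = -policy_loss_coef * (log_probs * advs).mean()
    policy_loss += -entropy_coef * action_pd.entropy().mean()  # entropy bonus encourages exploration
    return policy_loss

# usage: 4 states, 3 discrete actions
# loss = toy_policy_loss(torch.randn(4, 3), torch.tensor([0, 2, 1, 0]), torch.randn(4))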
def calc_nstep_advs_v_targets(self, batch, v_preds):
    '''
    Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets
    See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf
    '''
    next_states = batch['next_states'][-1]
    if not self.body.env.is_venv:
        next_states = next_states.unsqueeze(dim=0)
    with torch.no_grad():
        next_v_pred = self.calc_v(next_states, use_cache=False)
    v_preds = v_preds.detach()  # adv does not accumulate grad
    if self.body.env.is_venv:
        v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
    nstep_rets = math_util.calc_nstep_returns(batch['rewards'], batch['dones'], next_v_pred, self.gamma, self.num_step_returns)
    advs = nstep_rets - v_preds
    v_targets = nstep_rets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)
        v_targets = math_util.venv_unpack(v_targets)
    logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
    return advs, v_targets
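# --- Illustrative sketch (not the library implementation) ---
# Direct-from-definition n-step return, R^(n)_t = r_t + gamma*r_{t+1} + ... + gamma^(n-1)*r_{t+n-1} + gamma^n * V(s_{t+n}),
# truncated at episode boundaries. The interface below (per-step next-state values) is an assumption and
# differs from math_util.calc_nstep_returns above, which only receives the value of the final next state.
import torch

def nstep_returns_sketch(rewards, next_v_preds, dones, gamma, n):
    '''rewards, dones, next_v_preds: shape (T,); next_v_preds[t] estimates V(s_{t+1}).'''
    T = rewards.shape[0]
    rets = torch.zeros_like(rewards)
    for t in range(T):
        ret, discount, last = 0.0, 1.0, t
        for k in range(t, min(t + n, T)):
            ret += discount * rewards[k]
            discount *= gamma
            last = k
            if dones[k]:
                break
        if not dones[last]:  # bootstrap with V(s_{last+1}) only if the window did not hit a terminal state
            ret += discount * next_v_preds[last]
        rets[t] = ret
    return rets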
def calc_q_loss(self, batch):
    '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
    states = batch['states']
    next_states = batch['next_states']
    if self.body.env.is_venv:
        states = math_util.venv_unpack(states)
        next_states = math_util.venv_unpack(next_states)
    q_preds = self.net(states)
    with torch.no_grad():
        next_q_preds = self.net(next_states)
    if self.body.env.is_venv:
        q_preds = math_util.venv_pack(q_preds, self.body.env.num_envs)
        next_q_preds = math_util.venv_pack(next_q_preds, self.body.env.num_envs)
    act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
    act_next_q_preds = next_q_preds.gather(-1, batch['next_actions'].long().unsqueeze(-1)).squeeze(-1)
    act_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * act_next_q_preds
    logger.debug(f'act_q_preds: {act_q_preds}\nact_q_targets: {act_q_targets}')
    q_loss = self.net.loss_fn(act_q_preds, act_q_targets)
    return q_loss
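# --- Illustrative sketch (not part of the library) ---
# The temporal-difference target in calc_q_loss above is the SARSA-style target
# q = r + gamma * (1 - done) * Q(s', a'), where a' is the action actually taken in the next state.
# All names below are hypothetical placeholders; the library uses self.net.loss_fn instead of MSE by assumption.
import torch
import torch.nn.functional as F

def toy_q_loss(q_preds, next_q_preds, actions, next_actions, rewards, dones, gamma=0.99):
    '''q_preds, next_q_preds: (B, num_actions); actions, next_actions, rewards, dones: (B,).'''
    act_q_preds = q_preds.gather(-1, actions.long().unsqueeze(-1)).squeeze(-1)
    with torch.no_grad():  # targets must not backpropagate into the network
        act_next_q_preds = next_q_preds.gather(-1, next_actions.long().unsqueeze(-1)).squeeze(-1)
        act_q_targets = rewards + gamma * (1 - dones) * act_next_q_preds
    return F.mse_loss(act_q_preds, act_q_targets)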
def calc_pdparam_v(self, batch):
    '''
    Efficiently forward to get pdparam and v by batch for loss computation
    Args:
        batch: batch of experiences containing at least 'states'
    Returns:
        pdparam: action probability distribution parameters (e.g. logits for a discrete distribution) from the actor net
        v_pred: value prediction from the critic net
    '''
    states = batch['states']
    if self.body.env.is_venv:
        states = math_util.venv_unpack(states)
    pdparam = self.calc_pdparam(states)
    v_pred = self.calc_v(states)  # uses self.v_pred from calc_pdparam if self.shared
    return pdparam, v_pred
def calc_sil_policy_val_loss(self, batch, pdparams):
    '''
    Calculate the SIL policy losses for actor and critic
    sil_policy_loss = -log_prob * max(R - v_pred, 0)
    sil_val_loss = (max(R - v_pred, 0)^2) / 2
    This is called on a randomly-sampled batch from experience replay
    '''
    v_preds = self.calc_v(batch['states'], use_cache=False)
    rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    clipped_advs = torch.clamp(rets - v_preds, min=0.0)

    action_pd = policy_util.init_action_pd(self.body.ActionPD, pdparams)
    actions = batch['actions']
    if self.body.env.is_venv:
        actions = math_util.venv_unpack(actions)
    log_probs = action_pd.log_prob(actions)

    sil_policy_loss = -self.sil_policy_loss_coef * (log_probs * clipped_advs).mean()
    sil_val_loss = self.sil_val_loss_coef * clipped_advs.pow(2).mean() / 2
    logger.debug(f'SIL actor policy loss: {sil_policy_loss:g}')
    logger.debug(f'SIL critic value loss: {sil_val_loss:g}')
    return sil_policy_loss, sil_val_loss
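# --- Illustrative sketch (not part of the library) ---
# Self-Imitation Learning keeps only transitions whose realized return exceeded the value estimate:
# adv+ = max(R - V(s), 0). A toy numeric example of the two losses above, with hypothetical values:
import torch

rets = torch.tensor([2.0, 0.5, 1.0])        # discounted returns R
v_preds = torch.tensor([1.0, 1.0, 1.0])     # critic estimates V(s)
log_probs = torch.tensor([-0.5, -0.7, -0.2])
clipped_advs = torch.clamp(rets - v_preds, min=0.0)    # -> [1.0, 0.0, 0.0]
sil_policy_loss = -(log_probs * clipped_advs).mean()   # only better-than-expected transitions contribute
sil_val_loss = clipped_advs.pow(2).mean() / 2          # half squared clipped advantage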
def train(self):
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggyback on the batch, but remember not to pack or unpack these two keys
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                    self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                    loss = policy_loss + val_loss
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        # reset
        self.to_train = 0
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
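# --- Illustrative sketch (assumed semantics, not the library implementation) ---
# The train loop above flattens vector-env data for minibatch sampling and restores the
# (seq_len, num_envs, ...) layout before each loss computation. A minimal version of that
# reshaping, assuming math_util.venv_unpack / venv_pack are shape transforms of this kind:
import torch

def venv_unpack_sketch(v):
    '''(seq_len, num_envs, *shape) -> (seq_len * num_envs, *shape)'''
    return v.reshape(-1, *v.shape[2:])

def venv_pack_sketch(v, num_envs):
    '''(seq_len * num_envs, *shape) -> (seq_len, num_envs, *shape)'''
    return v.reshape(-1, num_envs, *v.shape[1:])

# round trip: x of shape (T, E, D) -> unpack -> (T*E, D) -> pack -> (T, E, D)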
def train(self):
    # torch.save(self.net.state_dict(), './reward_model/policy_pretrain.mdl')
    # raise ValueError("policy pretrain stops")
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if clock.epi > 700:
        self.pretrain_finished = True
        # torch.save(self.discriminator.state_dict(), './reward_model/airl_pretrain.mdl')
        # raise ValueError("pretrain stops here")
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        if self.reward_type == 'OFFGAN':
            batch = self.replace_reward_batch(batch)
        # if self.reward_type == 'DISC':
        #     batch = self.fetch_disc_reward(batch)
        # if self.reward_type == 'AIRL':
        #     batch = self.fetch_airl_reward(batch)
        # if self.reward_type == 'OFFGAN_update':
        #     batch = self.fetch_offgan_reward(batch)
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggyback on the batch, but remember not to pack or unpack these two keys
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            # if not self.pretrain_finished or not self.policy_training_flag:
            #     break
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    # while pretraining (pretrain_finished == False): keep the policy fixed, update only the value net and disc
                    if not self.pretrain_finished:
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = val_loss
                    if self.pretrain_finished and self.policy_training_flag:
                        self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = policy_loss + val_loss
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        if not self.pretrain_finished:
            logger.info(f'Warmup of value net, epi: {clock.epi}, frame: {clock.frame}, loss: {loss}')
        # reset
        self.to_train = 0
        self.policy_training_flag = False
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan