def calc_gae_advs_v_targets(self, batch, v_preds):
    '''
    Calculate GAE, and advs = GAE, v_targets = advs + v_preds
    See GAE from Schulman et al. https://arxiv.org/pdf/1506.02438.pdf
    '''
    next_states = batch['next_states'][-1]
    if not self.body.env.is_venv:
        next_states = next_states.unsqueeze(dim=0)
    with torch.no_grad():
        next_v_pred = self.calc_v(next_states, use_cache=False)
    v_preds = v_preds.detach()  # adv does not accumulate grad
    if self.body.env.is_venv:
        v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
        next_v_pred = next_v_pred.unsqueeze(dim=0)
    v_preds_all = torch.cat((v_preds, next_v_pred), dim=0)
    advs = math_util.calc_gaes(batch['rewards'], batch['dones'], v_preds_all, self.gamma, self.lam)
    v_targets = advs + v_preds
    advs = math_util.standardize(advs)  # standardize only for advs, not v_targets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)
        v_targets = math_util.venv_unpack(v_targets)
    logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
    return advs, v_targets
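# --- Illustrative sketch (not part of the original module) ---
# math_util.calc_gaes is called above with time-major rewards, dones, and v_preds_all,
# where v_preds_all carries one extra bootstrap value V(s_T). A minimal GAE recursion
# consistent with that call signature, assuming this shape convention, is:
#   delta_t = r_t + gamma * (1 - d_t) * V(s_{t+1}) - V(s_t)
#   A_t     = delta_t + gamma * lam * (1 - d_t) * A_{t+1}
import torch


def gae_sketch(rewards, dones, v_preds_all, gamma, lam):
    '''Hypothetical stand-in for math_util.calc_gaes; backward recursion over time.'''
    advs = torch.zeros_like(rewards)
    future_adv = torch.zeros_like(rewards[0])
    not_dones = 1.0 - dones
    for t in reversed(range(rewards.shape[0])):
        delta = rewards[t] + gamma * not_dones[t] * v_preds_all[t + 1] - v_preds_all[t]
        future_adv = delta + gamma * lam * not_dones[t] * future_adv
        advs[t] = future_adv
    return advs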
def calc_ret_advs_v_targets(self, batch, v_preds):
    '''Calculate plain returns, and advs = rets - v_preds, v_targets = rets'''
    v_preds = v_preds.detach()  # adv does not accumulate grad
    if self.body.env.is_venv:
        v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
    rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    advs = rets - v_preds
    v_targets = rets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)
        v_targets = math_util.venv_unpack(v_targets)
    logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
    return advs, v_targets
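# --- Illustrative sketch (not part of the original module) ---
# math_util.calc_returns above is assumed to compute plain discounted returns over a
# time-major batch, cutting the bootstrap at episode ends via dones:
#   G_t = r_t + gamma * (1 - d_t) * G_{t+1}
import torch


def returns_sketch(rewards, dones, gamma):
    '''Hypothetical stand-in for math_util.calc_returns.'''
    rets = torch.zeros_like(rewards)
    future_ret = torch.zeros_like(rewards[0])
    not_dones = 1.0 - dones
    for t in reversed(range(rewards.shape[0])):
        future_ret = rewards[t] + gamma * not_dones[t] * future_ret
        rets[t] = future_ret
    return rets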
def calc_q_loss(self, batch):
    '''Compute the Q value loss using predicted and target Q values from the appropriate networks'''
    states = batch['states']
    next_states = batch['next_states']
    if self.body.env.is_venv:
        states = math_util.venv_unpack(states)
        next_states = math_util.venv_unpack(next_states)
    q_preds = self.net(states)
    with torch.no_grad():
        next_q_preds = self.net(next_states)
    if self.body.env.is_venv:
        q_preds = math_util.venv_pack(q_preds, self.body.env.num_envs)
        next_q_preds = math_util.venv_pack(next_q_preds, self.body.env.num_envs)
    act_q_preds = q_preds.gather(-1, batch['actions'].long().unsqueeze(-1)).squeeze(-1)
    act_next_q_preds = next_q_preds.gather(-1, batch['next_actions'].long().unsqueeze(-1)).squeeze(-1)
    act_q_targets = batch['rewards'] + self.gamma * (1 - batch['dones']) * act_next_q_preds
    logger.debug(f'act_q_preds: {act_q_preds}\nact_q_targets: {act_q_targets}')
    q_loss = self.net.loss_fn(act_q_preds, act_q_targets)
    return q_loss
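# --- Usage sketch with toy tensors (not part of the original file) ---
# The gather/target arithmetic from calc_q_loss above, spelled out for a batch of 3
# transitions with 2 discrete actions. F.mse_loss stands in for self.net.loss_fn,
# which is an assumption about the configured loss function.
import torch
import torch.nn.functional as F


def q_loss_toy_example(gamma=0.99):
    '''Recompute the SARSA-style target and loss on hand-written Q values.'''
    q_preds = torch.tensor([[1.0, 2.0], [0.5, 0.1], [0.0, 3.0]])
    next_q_preds = torch.tensor([[1.5, 0.2], [0.3, 0.9], [2.0, 1.0]])
    actions = torch.tensor([1, 0, 1])
    next_actions = torch.tensor([0, 1, 0])
    rewards = torch.tensor([1.0, 0.0, -1.0])
    dones = torch.tensor([0.0, 0.0, 1.0])
    # select Q(s, a) and Q(s', a') along the action dimension
    act_q_preds = q_preds.gather(-1, actions.long().unsqueeze(-1)).squeeze(-1)
    act_next_q_preds = next_q_preds.gather(-1, next_actions.long().unsqueeze(-1)).squeeze(-1)
    act_q_targets = rewards + gamma * (1 - dones) * act_next_q_preds  # SARSA target
    return F.mse_loss(act_q_preds, act_q_targets)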
def calc_ret_advs_v_targets(self, batch, v_preds):
    '''
    Calculate plain returns, and advs = rets - v_preds, v_targets = rets

    Args:
        batch: sampled batch of transitions containing 'rewards' and 'dones'
        v_preds: critic value predictions for the batch states
    Returns:
        advs (discounted returns minus value predictions), v_targets (the discounted returns)
    '''
    v_preds = v_preds.detach()  # adv does not accumulate grad
    if self.body.env.is_venv:
        v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
    rets = math_util.calc_returns(batch['rewards'], batch['dones'], self.gamma)
    advs = rets - v_preds
    v_targets = rets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)
        v_targets = math_util.venv_unpack(v_targets)
    logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
    return advs, v_targets
def calc_nstep_advs_v_targets(self, batch, v_preds):
    '''
    Calculate N-step returns, and advs = nstep_rets - v_preds, v_targets = nstep_rets
    See n-step advantage under http://rail.eecs.berkeley.edu/deeprlcourse-fa17/f17docs/lecture_5_actor_critic_pdf.pdf
    '''
    next_states = batch['next_states'][-1]
    if not self.body.env.is_venv:
        next_states = next_states.unsqueeze(dim=0)
    with torch.no_grad():
        next_v_pred = self.calc_v(next_states, use_cache=False)
    v_preds = v_preds.detach()  # adv does not accumulate grad
    if self.body.env.is_venv:
        v_preds = math_util.venv_pack(v_preds, self.body.env.num_envs)
    nstep_rets = math_util.calc_nstep_returns(batch['rewards'], batch['dones'], next_v_pred, self.gamma, self.num_step_returns)
    advs = nstep_rets - v_preds
    v_targets = nstep_rets
    if self.body.env.is_venv:
        advs = math_util.venv_unpack(advs)
        v_targets = math_util.venv_unpack(v_targets)
    logger.debug(f'advs: {advs}\nv_targets: {v_targets}')
    return advs, v_targets
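# --- Illustrative sketch (not part of the original module) ---
# math_util.calc_nstep_returns above bootstraps the tail of the batch with the critic's
# next_v_pred. A backward recursion consistent with that call, assuming the batch length
# equals num_step_returns, is R_t = r_t + gamma * (1 - d_t) * R_{t+1}, seeded with V(s_T).
import torch


def nstep_returns_sketch(rewards, dones, next_v_pred, gamma):
    '''Hypothetical stand-in for math_util.calc_nstep_returns (n fixed to the batch length).'''
    rets = torch.zeros_like(rewards)
    future_ret = next_v_pred
    not_dones = 1.0 - dones
    for t in reversed(range(rewards.shape[0])):
        future_ret = rewards[t] + gamma * not_dones[t] * future_ret
        rets[t] = future_ret
    return rets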
def train(self):
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggy back on batch, but remember to not pack or unpack
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                    self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                    loss = policy_loss + val_loss
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        # reset
        self.to_train = 0
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
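# --- Illustrative sketch (not part of the original module) ---
# util.split_minibatch above is assumed to slice every key of the batch dict with the
# same shuffled indices, so states, actions, advs, and v_targets stay aligned across
# minibatches of size minibatch_size.
import torch


def split_minibatch_sketch(batch, minibatch_size):
    '''Hypothetical stand-in for util.split_minibatch.'''
    size = len(next(iter(batch.values())))
    idxs = torch.randperm(size)
    minibatches = []
    for start in range(0, size, minibatch_size):
        mb_idxs = idxs[start:start + minibatch_size]
        minibatches.append({k: v[mb_idxs] for k, v in batch.items()})
    return minibatches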
def train(self):
    # torch.save(self.net.state_dict(), './reward_model/policy_pretrain.mdl')
    # raise ValueError("policy pretrain stops")
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.body.env.clock.epi > 700:
        self.pretrain_finished = True
        # torch.save(self.discriminator.state_dict(), './reward_model/airl_pretrain.mdl')
        # raise ValueError("pretrain stops here")
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        if self.reward_type == 'OFFGAN':
            batch = self.replace_reward_batch(batch)
        # if self.reward_type == 'DISC':
        #     batch = self.fetch_disc_reward(batch)
        # if self.reward_type == 'AIRL':
        #     batch = self.fetch_airl_reward(batch)
        # if self.reward_type == 'OFFGAN_update':
        #     batch = self.fetch_offgan_reward(batch)
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggy back on batch, but remember to not pack or unpack
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            # if not self.pretrain_finished or not self.policy_training_flag:
            #     break
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    # pretrain_finished == False -> keep the policy fixed, update the value net and discriminator
                    if not self.pretrain_finished:
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = val_loss
                    if self.pretrain_finished and self.policy_training_flag:
                        self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = policy_loss + val_loss
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        if not self.pretrain_finished:
            logger.info('warmup value net, epi: {}, frame: {}, loss: {}'.format(clock.epi, clock.frame, loss))
        # reset
        self.to_train = 0
        self.policy_training_flag = False
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
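# --- Illustrative sketch (not part of the original module) ---
# Both train() variants pack and unpack tensors around minibatching when a vector env is
# used. math_util.venv_pack / venv_unpack are assumed to move between a time-major
# (time, num_envs, *shape) layout and a flat (time * num_envs, *shape) layout:
import torch


def venv_unpack_sketch(v):
    '''(time, num_envs, *shape) -> (time * num_envs, *shape); hypothetical stand-in.'''
    return v.reshape(-1, *v.shape[2:])


def venv_pack_sketch(v, num_envs):
    '''(time * num_envs, *shape) -> (time, num_envs, *shape); hypothetical stand-in.'''
    return v.reshape(-1, num_envs, *v.shape[1:])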