def update_nets(self):
    if util.frame_mod(self.body.env.clock.frame, self.net.update_frequency, self.body.env.num_envs):
        if self.net.update_type == 'replace':
            net_util.copy(self.net, self.target_net)
        elif self.net.update_type == 'polyak':
            net_util.polyak_update(self.net, self.target_net, self.net.polyak_coef)
        else:
            raise ValueError('Unknown net.update_type. Should be "replace" or "polyak". Exiting.')
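# A minimal sketch (an assumption, not SLM Lab's actual implementation) of the soft target
# update that net_util.polyak_update above is expected to perform: the target parameters
# move a small step toward the online net, controlled by polyak_coef.
def polyak_update_sketch(src_net, tar_net, polyak_coef=0.995):
    '''Soft update: theta_tar <- polyak_coef * theta_tar + (1 - polyak_coef) * theta_src.'''
    for src_param, tar_param in zip(src_net.parameters(), tar_net.parameters()):
        tar_param.data.copy_(polyak_coef * tar_param.data + (1.0 - polyak_coef) * src_param.data)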
def train_step(self, loss, optim, lr_scheduler, clock=None, global_net=None):
    lr_scheduler.step(epoch=ps.get(clock, 'frame'))
    optim.zero_grad()
    loss.backward()
    if self.clip_grad_val is not None:
        nn.utils.clip_grad_norm_(self.parameters(), self.clip_grad_val)
    if global_net is not None:
        net_util.push_global_grads(self, global_net)
    optim.step()
    if global_net is not None:
        net_util.copy(global_net, self)
    clock.tick('opt_step')
    return loss
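# Hedged sketch of the global-gradient push used in train_step above for Hogwild!/A3C-style
# asynchronous training; the real net_util.push_global_grads may differ. Local gradients are
# placed on the shared global net so its optimizer can step, after which the local net
# re-syncs from it via net_util.copy.
def push_global_grads_sketch(net, global_net):
    '''Copy local gradients onto the corresponding shared (global) parameters.'''
    for param, global_param in zip(net.parameters(), global_net.parameters()):
        if global_param.grad is None:  # skip if another worker already set the grad this step
            global_param._grad = param.grad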
def train(self):
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggyback on batch, but remember to not pack or unpack
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                    self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                    loss = policy_loss + val_loss
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        # reset
        self.to_train = 0
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
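# A minimal sketch, assuming calc_advs_v_targets above computes GAE-style advantages and
# critic value targets from rewards, done flags, and value predictions. The real method and
# its signature live in the actor-critic class; this stand-alone version is only illustration.
import torch

def calc_gaes_sketch(rewards, dones, v_preds, next_v_pred, gamma=0.99, lam=0.95):
    '''Generalized Advantage Estimation: delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t).'''
    T = len(rewards)
    advs = torch.zeros(T)
    future_adv = 0.0
    v_next = next_v_pred
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * v_next * not_done - v_preds[t]
        future_adv = delta + gamma * lam * not_done * future_adv
        advs[t] = future_adv
        v_next = v_preds[t]
    v_targets = advs + v_preds  # regression targets for the critic
    return advs, v_targets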
def train(self):
    # torch.save(self.net.state_dict(), './reward_model/policy_pretrain.mdl')
    # raise ValueError("policy pretrain stops")
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if clock.epi > 700:
        self.pretrain_finished = True
        # torch.save(self.discriminator.state_dict(), './reward_model/airl_pretrain.mdl')
        # raise ValueError("pretrain stops here")
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        if self.reward_type == 'OFFGAN':
            batch = self.replace_reward_batch(batch)
        # if self.reward_type == 'DISC':
        #     batch = self.fetch_disc_reward(batch)
        # if self.reward_type == 'AIRL':
        #     batch = self.fetch_airl_reward(batch)
        # if self.reward_type == 'OFFGAN_update':
        #     batch = self.fetch_offgan_reward(batch)
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggyback on batch, but remember to not pack or unpack
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            # if not self.pretrain_finished or not self.policy_training_flag:
            #     break
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    # pretrain_finished == False -> keep the policy fixed; update only the value net (and discriminator)
                    if not self.pretrain_finished:
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = val_loss
                    elif self.policy_training_flag:
                        # pretraining done and policy training enabled -> update both actor and critic
                        self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = policy_loss + val_loss
                    else:
                        # no update this minibatch; track the loss without training so `loss` is always bound
                        loss = (policy_loss + val_loss).detach()
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        if not self.pretrain_finished:
            logger.info('Warming up value net, epi: {}, frame: {}, loss: {}'.format(clock.epi, clock.frame, loss))
        # reset
        self.to_train = 0
        self.policy_training_flag = False
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
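# Hedged sketch of the reward replacement step above: replace_reward_batch is assumed to
# overwrite the environment rewards in the sampled batch with scores from a learned reward
# model (e.g. a GAN discriminator over (state, action) pairs). The attribute name
# self.reward_model is hypothetical; only the batch keys mirror the code above.
def replace_reward_batch_sketch(self, batch):
    '''Swap env rewards for learned-reward-model outputs, keeping the batch shape intact.'''
    with torch.no_grad():
        learned_rewards = self.reward_model(batch['states'], batch['actions'])
    batch['rewards'] = learned_rewards.reshape_as(batch['rewards'])
    return batch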