def GanReward_Update(self, training_times=1):
    for _ in range(training_times):
        batch = self.experience_buffer[-1]
        minibatches = util.split_minibatch(batch, 64)
        for fake_batch in minibatches:
            # reward agent computes the discriminator loss on the generated batch
            loss = self.reward_agent.update(fake_batch)
            self.optim_gandisc.zero_grad()
            loss.backward()
            # clip gradients elementwise to keep the discriminator update stable
            torch.nn.utils.clip_grad_value_(self.reward_agent.discriminator.parameters(), 0.5)
            self.optim_gandisc.step()

def disc_train(self, training_times=1):
    for t in range(training_times):
        # idx = min(t+1, len(self.experience_buffer))
        batch = self.experience_buffer[-1]
        minibatches = util.split_minibatch(batch, 64)
        for fake_batch in minibatches:
            self.optim_disc.zero_grad()
            loss = self.discriminator.disc_train(fake_batch)
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.discriminator.parameters(), 3)
            self.optim_disc.step()

def airl_train(self, training_times=1):
    for t in range(training_times):
        total_loss = 0
        # idx = min(t+1, len(self.experience_buffer))
        batch = self.experience_buffer[-1]
        minibatches = util.split_minibatch(batch, 64)
        for fake_batch in minibatches:
            self.optim_disc.zero_grad()
            loss = self.discriminator.disc_train(fake_batch)
            total_loss += loss.item()
            loss.backward()
            self.optim_disc.step()
            # WGAN-style weight clamping to keep the discriminator bounded
            for p in self.discriminator.parameters():
                p.data.clamp_(-0.1, 0.1)
        # report the loss averaged over all minibatches of this pass
        logger.info("airl training loss: {}".format(total_loss / len(minibatches)))
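
# All three update loops above depend on util.split_minibatch to chunk the latest
# rollout into fixed-size minibatches. The helper's source is not shown here; the
# function below is a hedged sketch of its assumed behavior (a dict of equal-length
# tensors split along the first dimension), not the actual util implementation.
def _split_minibatch_sketch(batch, mb_size):
    """Illustrative only: split a dict of equal-length tensors into minibatch dicts."""
    n = len(next(iter(batch.values())))
    return [{k: v[i:i + mb_size] for k, v in batch.items()} for i in range(0, n, mb_size)]
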
def train(self):
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggy back on batch, but remember to not pack or unpack
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                    self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                    loss = policy_loss + val_loss
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        # reset
        self.to_train = 0
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan

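
# train() above flattens venv batches with math_util.venv_unpack before minibatch
# splitting, then restores the (time, env) layout with math_util.venv_pack. The
# real helpers live in math_util and may differ; the functions below are hedged
# sketches of the assumed reshaping semantics only.
def _venv_unpack_sketch(v):
    """Illustrative only: (T, num_envs, *shape) -> (T * num_envs, *shape)."""
    return v.reshape(-1, *v.shape[2:])


def _venv_pack_sketch(v, num_envs):
    """Illustrative only: (T * num_envs, *shape) -> (T, num_envs, *shape)."""
    return v.reshape(-1, num_envs, *v.shape[1:])
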
def train(self):
    # torch.save(self.net.state_dict(), './reward_model/policy_pretrain.mdl')
    # raise ValueError("policy pretrain stops")
    if util.in_eval_lab_modes():
        return np.nan
    clock = self.body.env.clock
    if self.body.env.clock.epi > 700:
        self.pretrain_finished = True
        # torch.save(self.discriminator.state_dict(), './reward_model/airl_pretrain.mdl')
        # raise ValueError("pretrain stops here")
    if self.to_train == 1:
        net_util.copy(self.net, self.old_net)  # update old net
        batch = self.sample()
        if self.reward_type == 'OFFGAN':
            batch = self.replace_reward_batch(batch)
        # if self.reward_type == 'DISC':
        #     batch = self.fetch_disc_reward(batch)
        # if self.reward_type == 'AIRL':
        #     batch = self.fetch_airl_reward(batch)
        # if self.reward_type == 'OFFGAN_update':
        #     batch = self.fetch_offgan_reward(batch)
        clock.set_batch_size(len(batch))
        _pdparams, v_preds = self.calc_pdparam_v(batch)
        advs, v_targets = self.calc_advs_v_targets(batch, v_preds)
        # piggy back on batch, but remember to not pack or unpack
        batch['advs'], batch['v_targets'] = advs, v_targets
        if self.body.env.is_venv:  # unpack if venv for minibatch sampling
            for k, v in batch.items():
                if k not in ('advs', 'v_targets'):
                    batch[k] = math_util.venv_unpack(v)
        total_loss = torch.tensor(0.0)
        for _ in range(self.training_epoch):
            minibatches = util.split_minibatch(batch, self.minibatch_size)
            # if not self.pretrain_finished or not self.policy_training_flag:
            #     break
            for minibatch in minibatches:
                if self.body.env.is_venv:  # re-pack to restore proper shape
                    for k, v in minibatch.items():
                        if k not in ('advs', 'v_targets'):
                            minibatch[k] = math_util.venv_pack(v, self.body.env.num_envs)
                advs, v_targets = minibatch['advs'], minibatch['v_targets']
                pdparams, v_preds = self.calc_pdparam_v(minibatch)
                policy_loss = self.calc_policy_loss(minibatch, pdparams, advs)  # from actor
                val_loss = self.calc_val_loss(v_preds, v_targets)  # from critic
                if self.shared:  # shared network
                    loss = policy_loss + val_loss
                    self.net.train_step(loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                else:
                    loss = torch.tensor(0.0)  # default, so loss is bound even when no update runs this step
                    # pretrain_finished == False: keep the policy fixed; update only the value net (and disc)
                    if not self.pretrain_finished:
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = val_loss
                    if self.pretrain_finished and self.policy_training_flag:
                        self.net.train_step(policy_loss, self.optim, self.lr_scheduler, clock=clock, global_net=self.global_net)
                        self.critic_net.train_step(val_loss, self.critic_optim, self.critic_lr_scheduler, clock=clock, global_net=self.global_critic_net)
                        loss = policy_loss + val_loss
                total_loss += loss
        loss = total_loss / self.training_epoch / len(minibatches)
        if not self.pretrain_finished:
            logger.info("warmup Value net, epi: {}, frame: {}, loss: {}".format(clock.epi, clock.frame, loss))
        # reset
        self.to_train = 0
        self.policy_training_flag = False
        logger.debug(f'Trained {self.name} at epi: {clock.epi}, frame: {clock.frame}, t: {clock.t}, total_reward so far: {self.body.total_reward}, loss: {loss:g}')
        return loss.item()
    else:
        return np.nan
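
# Both train() methods call self.calc_advs_v_targets to produce the advantages and
# value targets behind policy_loss and val_loss. The sketch below assumes a
# generalized-advantage-estimation (GAE) computation over a single trajectory; the
# agent's actual helper is not shown here and may differ.
import torch


def _gae_sketch(rewards, v_preds, next_v_pred, dones, gamma=0.99, lam=0.95):
    """Illustrative only: returns (advs, v_targets) for 1-D reward/value/done tensors."""
    T = len(rewards)
    advs = torch.zeros(T)
    gae = 0.0
    next_v = next_v_pred
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        # one-step TD error, zeroing the bootstrap at episode boundaries
        delta = rewards[t] + gamma * next_v * not_done - v_preds[t]
        gae = delta + gamma * lam * not_done * gae
        advs[t] = gae
        next_v = v_preds[t]
    v_targets = advs + v_preds  # value regression target = advantage + baseline
    return advs, v_targets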