def optimize_agent(self, itr, samples):
    """
    Train the agent for multiple epochs over minibatches taken from the
    input samples.  Organizes agent inputs from the training data and
    moves them to device (e.g. GPU) up front, so that minibatches are
    formed within the device, without further data transfer.
    """
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    if hasattr(self.agent, "update_obs_rms"):
        self.agent.update_obs_rms(agent_inputs.observation)
    return_, advantage, valid = self.process_returns(samples, self.normalize_rewards)
    loss_inputs = LossInputs(  # So can slice all.
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        return_=return_,
        advantage=advantage,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
    )
    if recurrent:
        # Leave in [B,N,H] for slicing to minibatches.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
    T, B = samples.env.reward.shape[:2]
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
    batch_size = B if recurrent else T * B
    mb_size = batch_size // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            self.optimizer.zero_grad()
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            # NOTE: if not recurrent, will lose leading T dim; should be OK.
            loss, entropy, perplexity = self.loss(
                *loss_inputs[T_idxs, B_idxs], rnn_state)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()
            opt_info.loss.append(loss.item())
            opt_info.gradNorm.append(grad_norm)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
            self.update_counter += 1
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
        self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    # if self.vae_lr_scheduler:
    #     self.vae_lr_scheduler.step()
    return opt_info
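# All of the minibatch loops in this file draw index batches from
# `iterate_mb_idxs`.  A minimal sketch of rlpyt's generator, consistent
# with how it is called here (with shuffle=True it yields integer index
# arrays, so that `idxs % T` and `idxs // T` work; with shuffle=False it
# yields plain slices):
import numpy as np

def iterate_mb_idxs(data_length, minibatch_size, shuffle=False):
    """Yield minibatches of indexes (or slices, when not shuffling)."""
    if shuffle:
        indexes = np.random.permutation(data_length)
    for start in range(0, data_length - minibatch_size + 1, minibatch_size):
        batch = slice(start, start + minibatch_size)
        if shuffle:
            batch = indexes[batch]
        yield batch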
def optimize_agent(self, itr, samples):
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    return_, advantage, valid = self.process_returns(samples)
    loss_inputs = LossInputs(  # So can slice all.
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        return_=return_,
        advantage=advantage,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
    )
    if recurrent:
        # Leave in [B,N,H] for slicing to minibatches.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
    T, B = samples.env.reward.shape[:2]
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
    batch_size = B if recurrent else T * B
    mb_size = batch_size // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            # NOTE: if not recurrent, will lose leading T dim; should be OK.
            pi_loss, value_loss, entropy, perplexity = self.loss(
                *loss_inputs[T_idxs, B_idxs], rnn_state)

            # Policy update; retain the graph in case the policy and value
            # losses share computation (e.g. a common encoder).
            self.optimizer.zero_grad()
            pi_loss.backward(retain_graph=True)
            pi_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()

            # Value update on its own optimizer.
            self.v_optimizer.zero_grad()
            value_loss.backward()
            v_grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.v_optimizer.step()

            opt_info.loss.append(pi_loss.item())
            opt_info.gradNorm.append(pi_grad_norm)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
            self.update_counter += 1
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
        self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    return opt_info
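# The variant above steps `self.optimizer` (policy) and `self.v_optimizer`
# (value) separately.  A minimal sketch of how the two optimizers could be
# constructed; `pi_parameters` and `v_parameters` are assumed accessors for
# the two parameter groups, not names taken from this code:
import torch

def optim_initialize(self, rank=0):
    self.optimizer = torch.optim.Adam(
        self.agent.pi_parameters(), lr=self.learning_rate)       # Assumed accessor.
    self.v_optimizer = torch.optim.Adam(
        self.agent.v_parameters(), lr=self.value_learning_rate)  # Assumed accessor.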
def curiosity_step(self, curiosity_type, *args):
    curiosity_model = (self.model.module.curiosity_model
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel)
        else self.model.curiosity_model)
    curiosity_step_minibatches = self.model_kwargs[
        'curiosity_step_kwargs']['curiosity_step_minibatches']
    T, B = args[0].shape[:2]  # Either observation or next_observation.
    batch_size = B
    mb_size = batch_size // curiosity_step_minibatches
    if curiosity_type in {'icm', 'micm', 'disagreement'}:
        observation, next_observation, actions = args
        actions = self.distribution.to_onehot(actions)
        curiosity_agent_inputs = IcmAgentCuriosityStepInputs(
            observation=observation,
            next_observation=next_observation,
            actions=actions,
        )
        agent_curiosity_info = IcmInfo()
    elif curiosity_type == 'ndigo':
        observation, prev_actions, actions = args
        actions = self.distribution.to_onehot(actions)
        prev_actions = self.distribution.to_onehot(prev_actions)
        curiosity_agent_inputs = NdigoAgentCuriosityStepInputs(
            observations=observation,
            prev_actions=prev_actions,
            actions=actions,
        )
        agent_curiosity_info = NdigoInfo(prev_gru_state=None)
    elif curiosity_type == 'rnd':
        next_observation, done = args
        curiosity_agent_inputs = RndAgentCuriosityStepInputs(
            next_observation=next_observation,
            done=done,
        )
        agent_curiosity_info = RndInfo()
    curiosity_agent_inputs = buffer_to(curiosity_agent_inputs, device=self.device)
    # Split the intrinsic reward predictions into several minibatches over
    # B -- otherwise, we will run out of GPU memory.
    r_ints = []
    for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=False):
        B_idxs = idxs
        mb_r_int = curiosity_model.compute_bonus(
            *curiosity_agent_inputs[slice(None), B_idxs])
        r_ints.append(mb_r_int)
    r_int = torch.cat(r_ints, dim=1)
    r_int, agent_curiosity_info = buffer_to(
        (r_int, agent_curiosity_info), device="cpu")
    return AgentCuriosityStep(
        r_int=r_int, agent_curiosity_info=agent_curiosity_info)
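# `buffer_to` moves a (possibly nested) namedtuple buffer leaf-wise onto a
# device.  A minimal sketch of rlpyt's utility (the real version also
# handles numpy arrays and other edge cases):
import torch

def buffer_to(buffer_, device=None):
    if buffer_ is None:
        return None
    if isinstance(buffer_, torch.Tensor):
        return buffer_.to(device)
    # Namedtuples rebuild themselves from their converted fields.
    return type(buffer_)(*(buffer_to(b, device) for b in buffer_))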
def optimize_agent(self, itr, samples):
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    return_, advantage, valid = self.process_returns(samples)
    loss_inputs = LossInputs(
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        return_=return_,
        advantage=advantage,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
    )
    if recurrent:
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]
    T, B = samples.env.reward.shape[:2]
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    batch_size = B if recurrent else T * B
    mb_size = batch_size // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            self.optimizer.zero_grad()
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            loss, entropy, perplexity = self.loss(
                *loss_inputs[T_idxs, B_idxs], rnn_state)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()
            opt_info.loss.append(loss.item())
            opt_info.gradNorm.append(grad_norm)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
        self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    return opt_info
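# For reference, a minimal sketch of the clipped-surrogate `self.loss`
# these methods call, following rlpyt's PPO.  The recurrent/rnn_state path
# is elided, and `valid_mean` (rlpyt's masked mean) is defined inline so
# the sketch is self-contained:
import torch

def valid_mean(tensor, valid=None):
    # Average only over valid (non-padded) time steps.
    if valid is None:
        return tensor.mean()
    return (tensor * valid).sum() / valid.sum()

def loss(self, agent_inputs, action, return_, advantage, valid,
        old_dist_info, init_rnn_state=None):
    dist_info, value = self.agent(*agent_inputs)  # Recurrent path elided.
    dist = self.agent.distribution
    ratio = dist.likelihood_ratio(action,
        old_dist_info=old_dist_info, new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped = torch.clamp(ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    surr_2 = clipped * advantage
    pi_loss = -valid_mean(torch.min(surr_1, surr_2), valid)
    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)
    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy
    loss = pi_loss + value_loss + entropy_loss
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity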
def optimize_agent(self, itr, samples=None, sampler_itr=None):
    """
    Train the agent for multiple epochs over minibatches taken from the
    input samples.  Organizes agent inputs from the training data and
    moves them to device (e.g. GPU) up front, so that minibatches are
    formed within the device, without further data transfer.
    """
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    init_rnn_states = buffer_to(
        samples.agent.agent_info.prev_rnn_state[0],  # T=0.
        device=self.agent.device)
    T, B = samples.env.reward.shape[:2]
    mb_size = B // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(B, mb_size, shuffle=True):
            self.optimizer.zero_grad()
            # [B,N,H] --> [N,B,H] (for cudnn).
            init_rnn_state = buffer_method(
                init_rnn_states[idxs], "transpose", 0, 1)
            dist_info, value, _ = self.agent(
                *agent_inputs[:, idxs], init_rnn_state)
            loss, opt_info = self.process_returns(
                samples.env.reward[:, idxs],
                done=samples.env.done[:, idxs],
                value_prediction=value.cpu(),
                action=samples.agent.action[:, idxs],
                dist_info=dist_info,
                old_dist_info=samples.agent.agent_info.dist_info[:, idxs],
                opt_info=opt_info,
            )
            loss.backward()
            self.optimizer.step()
            self.clamp_lagrange_multipliers()
            opt_info.loss.append(loss.item())
            self.update_counter += 1
    return opt_info
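# `buffer_method` applies a named tensor method leaf-wise across a
# (possibly nested) namedtuple buffer, as with the transpose above.  A
# minimal sketch of rlpyt's utility:
import torch

def buffer_method(buffer_, method_name, *args, **kwargs):
    if buffer_ is None:
        return None
    if isinstance(buffer_, torch.Tensor):
        return getattr(buffer_, method_name)(*args, **kwargs)
    return type(buffer_)(*(buffer_method(b, method_name, *args, **kwargs)
        for b in buffer_))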
def optimize_agent(self, itr, samples):
    """
    Override to provide additional flexibility in what enters the
    combined_loss function.
    """
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    if hasattr(self.agent, "update_obs_rms"):
        self.agent.update_obs_rms(agent_inputs.observation)

    # Process extrinsic returns and advantages.
    ext_rew, done, ext_val, ext_bv = (samples.env.reward,
        samples.env.done,
        samples.agent.agent_info.ext_value,
        samples.agent.bootstrap_value)
    done = done.type(ext_rew.dtype)
    if self.ext_rew_clip:  # Clip extrinsic reward, if specified.
        rew_min, rew_max = self.ext_rew_clip
        ext_rew = ext_rew.clamp(rew_min, rew_max)
    ext_return, ext_adv, valid = self.process_extrinsic_returns(
        ext_rew, done, ext_val, ext_bv)

    # Gather next observations, or fill with a dummy placeholder (the
    # current obs).  The agent decides what it extracts and uses as input
    # to its model, so the dummy tensor scenario has no effect.
    next_obs = (samples.env.next_observation
        if "next_observation" in samples.env else samples.env.observation)

    # First call to the bonus model generates intrinsic rewards for the
    # samples batch.  The [T,B] leading dims are flattened, and the
    # resulting rewards are unflattened afterward.
    batch_shape = samples.env.observation.shape[:2]
    bonus_model_inputs = self.agent.extract_bonus_inputs(
        observation=samples.env.observation.flatten(end_dim=1),
        # May be same as observation (dummy) if algo set next_obs=False.
        next_observation=next_obs.flatten(end_dim=1),
        action=samples.agent.action.flatten(end_dim=1),
    )
    # Bonus model will update any normalization models where applicable.
    self.agent.set_norm_update(True)
    with torch.no_grad():
        int_rew, _ = self.agent.bonus_call(bonus_model_inputs)
    int_rew = int_rew.view(batch_shape)

    # Process intrinsic returns and advantages (updating the intrinsic
    # reward normalization model, if applicable).
    int_val, int_bv = (samples.agent.agent_info.int_value,
        samples.agent.int_bootstrap_value)
    int_return, int_adv = self.process_intrinsic_returns(
        int_rew, int_val, int_bv)
    # Avoid repeating norm updates on the same data in subsequent loss calls.
    self.agent.set_norm_update(False)

    # Add front-processed optimizer data to the logging buffer.  Flattened
    # to match elsewhere, though the ultimate statistics summarize over
    # all dims anyway.
    opt_info.extrinsicValue.extend(ext_val.flatten().tolist())
    opt_info.intrinsicValue.extend(int_val.flatten().tolist())
    opt_info.intrinsicReward.extend(int_rew.flatten().tolist())
    opt_info.discountedIntrinsicReturn.extend(int_return.flatten().tolist())
    opt_info.meanObsRmsModel.extend(
        self.agent.bonus_model.obs_rms.mean.flatten().tolist())
    opt_info.varObsRmsModel.extend(
        self.agent.bonus_model.obs_rms.var.flatten().tolist())
    opt_info.meanIntRetRmsModel.extend(
        self.agent.bonus_model.int_rff_rms.mean.flatten().tolist())
    opt_info.varIntRetRmsModel.extend(
        self.agent.bonus_model.int_rff_rms.var.flatten().tolist())

    loss_inputs = LossInputs(  # So can slice all.
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        next_obs=next_obs,
        ext_return=ext_return,
        ext_adv=ext_adv,
        int_return=int_return,
        int_adv=int_adv,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
    )
    if recurrent:
        # Leave in [B,N,H] for slicing to minibatches.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
    # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
    T, B = samples.env.reward.shape[:2]
    batch_size = B if recurrent else T * B
    mb_size = batch_size // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            self.optimizer.zero_grad()
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            # NOTE: if not recurrent, will lose leading T dim; should be OK.
            # Combined loss produces a single loss for both actor and
            # bonus model.
            loss, entropy, perplexity, pi_loss, value_loss, entropy_loss, \
                bonus_loss = self.combined_loss(
                    *loss_inputs[T_idxs, B_idxs], rnn_state)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()
            opt_info.loss.append(loss.item())
            opt_info.policyLoss.append(pi_loss.item())
            opt_info.valueLoss.append(value_loss.item())
            opt_info.entropyLoss.append(entropy_loss.item())
            opt_info.bonusLoss.append(bonus_loss.item())
            opt_info.gradNorm.append(grad_norm)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
            self.update_counter += 1
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
        self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    return opt_info
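# With separate extrinsic and intrinsic returns as above, a combined loss
# typically mixes the two advantage streams before the surrogate.  A
# sketch, using the RND paper's coefficients (2.0 extrinsic / 1.0
# intrinsic) as illustrative defaults -- not values taken from this code:
def combine_advantages(ext_adv, int_adv, ext_coeff=2.0, int_coeff=1.0):
    return ext_coeff * ext_adv + int_coeff * int_adv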
def compute_beta_kl(self, loss_inputs, init_rnn_state, batch_size, mb_size, T):
    """Ratio of KL divergences from reward-only vs cost-only updates."""
    self.agent.beta_r_model.load_state_dict(
        strip_ddp_state_dict(self.agent.model.state_dict()))
    self.agent.beta_c_model.load_state_dict(
        strip_ddp_state_dict(self.agent.model.state_dict()))
    self.beta_r_optimizer.load_state_dict(self.optimizer.state_dict())
    self.beta_c_optimizer.load_state_dict(self.optimizer.state_dict())
    recurrent = self.agent.recurrent
    for _ in range(self.beta_kl_epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size,
                shuffle=batch_size > mb_size):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            self.beta_r_optimizer.zero_grad()
            self.beta_c_optimizer.zero_grad()
            beta_r_loss, beta_c_loss = self.beta_kl_losses(
                *loss_inputs[T_idxs, B_idxs], rnn_state)
            beta_r_loss.backward()
            _ = torch.nn.utils.clip_grad_norm_(
                self.agent.beta_r_model.parameters(), self.clip_grad_norm)
            self.beta_r_optimizer.step()
            beta_c_loss.backward()
            _ = torch.nn.utils.clip_grad_norm_(
                self.agent.beta_c_model.parameters(), self.clip_grad_norm)
            self.beta_c_optimizer.step()
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
    with torch.no_grad():
        r_dist_info, c_dist_info = self.agent.beta_dist_infos(
            *loss_inputs.agent_inputs, init_rnn_state)
    dist = self.agent.distribution
    beta_r_KL = dist.mean_kl(new_dist_info=r_dist_info,
        old_dist_info=loss_inputs.old_dist_info, valid=loss_inputs.valid)
    beta_c_KL = dist.mean_kl(new_dist_info=c_dist_info,
        old_dist_info=loss_inputs.old_dist_info, valid=loss_inputs.valid)
    if self._ddp:
        beta_KLs = torch.stack([beta_r_KL, beta_c_KL]).to(self.agent.device)
        torch.distributed.all_reduce(beta_KLs)
        beta_KLs = beta_KLs.to("cpu") / torch.distributed.get_world_size()
        beta_r_KL, beta_c_KL = beta_KLs[0], beta_KLs[1]
    raw_beta_KL = float(beta_r_KL / max(beta_c_KL, 1e-8))
    return raw_beta_KL, float(beta_r_KL), float(beta_c_KL)
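# For a categorical policy, `dist.mean_kl` above is the discrete KL
# averaged over valid steps.  A minimal sketch (rlpyt's version likewise
# guards the logs with a small epsilon):
import torch

def mean_kl(new_dist_info, old_dist_info, valid=None, eps=1e-8):
    p = old_dist_info.prob
    q = new_dist_info.prob
    kl = torch.sum(p * (torch.log(p + eps) - torch.log(q + eps)), dim=-1)
    if valid is None:
        return kl.mean()
    return (kl * valid).sum() / valid.sum()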
def optimize_agent(self, itr, samples):
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    (return_, advantage, valid, c_return, c_advantage,
        ep_cost_avg) = self.process_returns(itr, samples)
    loss_inputs = LossInputs(  # So can slice all.
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        return_=return_,
        advantage=advantage,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
        c_return=c_return,  # Can be None.
        c_advantage=c_advantage,
    )
    opt_info = OptInfoCost(*([] for _ in range(len(OptInfoCost._fields))))
    if (self.step_cost_limit_itr is not None
            and self.step_cost_limit_itr == itr):
        self.cost_limit = self.step_cost_limit_value
    opt_info.costLimit.append(self.cost_limit)

    # PID update on the cost penalty.
    delta = float(ep_cost_avg - self.cost_limit)  # ep_cost_avg: tensor.
    self.pid_i = max(0., self.pid_i + delta * self.pid_Ki)
    if self.diff_norm:
        self.pid_i = max(0., min(1., self.pid_i))
    a_p = self.pid_delta_p_ema_alpha
    self._delta_p *= a_p
    self._delta_p += (1 - a_p) * delta
    a_d = self.pid_delta_d_ema_alpha
    self._cost_d *= a_d
    self._cost_d += (1 - a_d) * float(ep_cost_avg)
    pid_d = max(0., self._cost_d - self.cost_ds[0])
    pid_o = (self.pid_Kp * self._delta_p + self.pid_i + self.pid_Kd * pid_d)
    self.cost_penalty = max(0., pid_o)
    if self.diff_norm:
        self.cost_penalty = min(1., self.cost_penalty)
    if not (self.diff_norm or self.sum_norm):
        self.cost_penalty = min(self.cost_penalty, self.penalty_max)
    self.cost_ds.append(self._cost_d)
    opt_info.pid_i.append(self.pid_i)
    opt_info.pid_p.append(self._delta_p)
    opt_info.pid_d.append(pid_d)
    opt_info.pid_o.append(pid_o)
    opt_info.costPenalty.append(self.cost_penalty)

    if hasattr(self.agent, "update_obs_rms"):
        self.agent.update_obs_rms(agent_inputs.observation)
        if itr == 0:
            return opt_info  # Sacrifice the first batch to get obs stats.
    if recurrent:
        # Leave in [B,N,H] for slicing to minibatches.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
    T, B = samples.env.reward.shape[:2]
    # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
    batch_size = B if recurrent else T * B
    mb_size = batch_size // self.minibatches
    if self.use_beta_kl or self.record_beta_kl:
        raw_beta_kl, beta_r_kl, beta_c_kl = self.compute_beta_kl(
            loss_inputs, init_rnn_state, batch_size, mb_size, T)
        beta_KL = min(self.beta_max, max(self.beta_min, raw_beta_kl))
        self._beta_kl *= self.beta_ema_alpha
        self._beta_kl += (1 - self.beta_ema_alpha) * beta_KL
        opt_info.betaKlRaw.append(raw_beta_kl)
        opt_info.betaKL.append(self._beta_kl)
        opt_info.betaKlR.append(beta_r_kl)
        opt_info.betaKlC.append(beta_c_kl)
    if self.use_beta_grad or self.record_beta_grad:
        raw_beta_grad = self.compute_beta_grad(loss_inputs, init_rnn_state)
        beta_grad = min(self.beta_max, max(self.beta_min, raw_beta_grad))
        self._beta_grad *= self.beta_ema_alpha
        self._beta_grad += (1 - self.beta_ema_alpha) * beta_grad
        opt_info.betaGradRaw.append(raw_beta_grad)
        opt_info.betaGrad.append(self._beta_grad)
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            self.optimizer.zero_grad()
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            # NOTE: if not recurrent, will lose leading T dim; should be OK.
            loss, entropy, perplexity, value_errors, abs_value_errors = \
                self.loss(*loss_inputs[T_idxs, B_idxs], rnn_state)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()
            opt_info.loss.append(loss.item())
            opt_info.gradNorm.append(grad_norm)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
            # Subsample every 10th value error to keep logging light.
            opt_info.valueError.extend(value_errors[0][::10].numpy())
            opt_info.cvalueError.extend(value_errors[1][::10].numpy())
            opt_info.valueAbsError.extend(abs_value_errors[0][::10].numpy())
            opt_info.cvalueAbsError.extend(abs_value_errors[1][::10].numpy())
            self.update_counter += 1
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
        self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    return opt_info
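# The PID update above, extracted into a standalone controller for
# clarity (Lagrange-multiplier control in the style of Stooke et al.'s
# PID approach).  Attribute names mirror those used in optimize_agent,
# but the defaults here are illustrative assumptions:
import collections

class CostPenaltyPID:
    def __init__(self, cost_limit, Kp=1.0, Ki=0.1, Kd=0.0,
            p_alpha=0.9, d_alpha=0.9, d_window=5, penalty_max=100.):
        self.cost_limit = cost_limit
        self.Kp, self.Ki, self.Kd = Kp, Ki, Kd
        self.p_alpha, self.d_alpha = p_alpha, d_alpha
        self.penalty_max = penalty_max
        self.i_term = 0.
        self.delta_p = 0.   # EMA of the constraint violation.
        self.cost_d = 0.    # EMA of the episodic cost.
        self.cost_ds = collections.deque([0.], maxlen=d_window)

    def update(self, ep_cost_avg):
        delta = ep_cost_avg - self.cost_limit
        self.i_term = max(0., self.i_term + delta * self.Ki)
        self.delta_p = self.p_alpha * self.delta_p + (1 - self.p_alpha) * delta
        self.cost_d = (self.d_alpha * self.cost_d
            + (1 - self.d_alpha) * ep_cost_avg)
        pid_d = max(0., self.cost_d - self.cost_ds[0])  # Windowed derivative.
        pid_o = self.Kp * self.delta_p + self.i_term + self.Kd * pid_d
        self.cost_ds.append(self.cost_d)
        return min(max(0., pid_o), self.penalty_max)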
def optimize_agent(self, itr, samples):
    """
    Train the agent for multiple epochs over minibatches taken from the
    input samples.  Organizes agent inputs from the training data and
    moves them to device (e.g. GPU) up front, so that minibatches are
    formed within the device, without further data transfer.
    """
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    if hasattr(self.agent, "update_obs_rms"):
        self.agent.update_obs_rms(agent_inputs.observation)
    if self.agent.dual_model:
        (return_, advantage, valid, return_int_,
            advantage_int) = self.process_returns(samples)
    else:
        return_, advantage, valid = self.process_returns(samples)

    if self.curiosity_type in {'icm', 'micm', 'disagreement'}:
        agent_curiosity_inputs = IcmAgentCuriosityInputs(
            observation=samples.env.observation.clone(),
            next_observation=samples.env.next_observation.clone(),
            action=samples.agent.action.clone(),
            valid=valid,
        )
    elif self.curiosity_type == 'ndigo':
        agent_curiosity_inputs = NdigoAgentCuriosityInputs(
            observation=samples.env.observation.clone(),
            prev_actions=samples.agent.prev_action.clone(),
            actions=samples.agent.action.clone(),
            valid=valid,
        )
    elif self.curiosity_type == 'rnd':
        agent_curiosity_inputs = RndAgentCuriosityInputs(
            next_observation=samples.env.next_observation.clone(),
            valid=valid,
        )
    elif self.curiosity_type == 'none':
        agent_curiosity_inputs = None
    if agent_curiosity_inputs is not None:
        agent_curiosity_inputs = buffer_to(agent_curiosity_inputs,
            device=self.agent.device)

    if self.policy_loss_type == 'dual':
        loss_inputs = LossInputsTwin(  # So can slice all.
            agent_inputs=agent_inputs,
            agent_curiosity_inputs=agent_curiosity_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
            return_int_=return_int_,
            advantage_int=advantage_int,
            old_dist_int_info=samples.agent.agent_info.dist_int_info,
        )
    else:
        loss_inputs = LossInputs(  # So can slice all.
            agent_inputs=agent_inputs,
            agent_curiosity_inputs=agent_curiosity_inputs,
            action=samples.agent.action,
            return_=return_,
            advantage=advantage,
            valid=valid,
            old_dist_info=samples.agent.agent_info.dist_info,
        )
    if recurrent:
        # Leave in [B,N,H] for slicing to minibatches.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
        if self.agent.dual_model:
            init_int_rnn_state = \
                samples.agent.agent_info.prev_int_rnn_state[0]  # T=0.
    T, B = samples.env.reward.shape[:2]
    if self.policy_loss_type == 'dual':
        opt_info = OptInfoTwin(*([] for _ in range(len(OptInfoTwin._fields))))
    else:
        opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
    batch_size = B if recurrent else T * B
    mb_size = batch_size // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            self.optimizer.zero_grad()
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            # NOTE: if not recurrent, will lose leading T dim; should be OK.
            if self.policy_loss_type == 'dual':
                int_rnn_state = (init_int_rnn_state[B_idxs]
                    if recurrent else None)
                loss_inputs_batch = loss_inputs[T_idxs, B_idxs]
                (loss, pi_loss, value_loss, entropy_loss, entropy, perplexity,
                    int_pi_loss, int_value_loss, int_entropy_loss, int_entropy,
                    int_perplexity, curiosity_losses) = self.loss(
                        agent_inputs=loss_inputs_batch.agent_inputs,
                        agent_curiosity_inputs=loss_inputs_batch.agent_curiosity_inputs,
                        action=loss_inputs_batch.action,
                        return_=loss_inputs_batch.return_,
                        advantage=loss_inputs_batch.advantage,
                        valid=loss_inputs_batch.valid,
                        old_dist_info=loss_inputs_batch.old_dist_info,
                        return_int_=loss_inputs_batch.return_int_,
                        advantage_int=loss_inputs_batch.advantage_int,
                        old_dist_int_info=loss_inputs_batch.old_dist_int_info,
                        init_rnn_state=rnn_state,
                        init_int_rnn_state=int_rnn_state)
            else:
                (loss, pi_loss, value_loss, entropy_loss, entropy,
                    perplexity, curiosity_losses) = self.loss(
                        *loss_inputs[T_idxs, B_idxs], rnn_state)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()

            # Tensorboard summaries.
            opt_info.loss.append(loss.item())
            opt_info.pi_loss.append(pi_loss.item())
            opt_info.value_loss.append(value_loss.item())
            opt_info.entropy_loss.append(entropy_loss.item())
            if self.policy_loss_type == 'dual':
                opt_info.int_pi_loss.append(int_pi_loss.item())
                opt_info.int_value_loss.append(int_value_loss.item())
                opt_info.int_entropy_loss.append(int_entropy_loss.item())
            if self.curiosity_type in {'icm', 'micm'}:
                inv_loss, forward_loss = curiosity_losses
                opt_info.inv_loss.append(inv_loss.item())
                opt_info.forward_loss.append(forward_loss.item())
                opt_info.intrinsic_rewards.append(
                    np.mean(self.intrinsic_rewards))
                opt_info.extint_ratio.append(np.mean(self.extint_ratio))
            elif self.curiosity_type in {'disagreement', 'ndigo', 'rnd'}:
                forward_loss = curiosity_losses
                opt_info.forward_loss.append(forward_loss.item())
                opt_info.intrinsic_rewards.append(
                    np.mean(self.intrinsic_rewards))
                opt_info.extint_ratio.append(np.mean(self.extint_ratio))
            if self.normalize_reward:
                opt_info.reward_total_std.append(self.reward_rms.var ** 0.5)
                if self.policy_loss_type == 'dual':
                    opt_info.int_reward_total_std.append(
                        self.int_reward_rms.var ** 0.5)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
            if self.policy_loss_type == 'dual':
                opt_info.int_entropy.append(int_entropy.item())
                opt_info.int_perplexity.append(int_perplexity.item())
            self.update_counter += 1

    opt_info.return_.append(return_.mean().item())
    opt_info.advantage.append(advantage.mean().item())
    opt_info.valpred.append(samples.agent.agent_info.value.mean().item())
    if self.policy_loss_type == 'dual':
        opt_info.return_int_.append(return_int_.mean().item())
        opt_info.advantage_int.append(advantage_int.mean().item())
        opt_info.int_valpred.append(
            samples.agent.agent_info.int_value.mean().item())
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
        self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    # Empty dict to store model layer weights for tensorboard visualizations.
    layer_info = dict()
    return opt_info, layer_info
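# The linear schedule stepped via `self.lr_scheduler` above can be built
# with PyTorch's LambdaLR.  A minimal, self-contained sketch (the model
# and hyperparameter values are stand-ins, not taken from this code):
import torch

model = torch.nn.Linear(4, 2)  # Stand-in for the agent's model.
optimizer = torch.optim.Adam(model.parameters(), lr=3e-4)
n_itr = 1000
lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer, lr_lambda=lambda itr: (n_itr - itr) / n_itr)  # Decay to 0.
# One scheduler step per optimizer update, as in the loops above:
optimizer.step()
lr_scheduler.step()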
def compute_minibatch_gradients(self, samples):
    """
    Compute gradients over the full batch (taken as a single minibatch),
    separating policy-network from value-network gradients.  Organizes
    agent inputs from the training data and moves them to device (e.g.
    GPU) up front, so that minibatches are formed within the device,
    without further data transfer.
    """
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    return_, advantage, valid, value, reward, pre_reward = \
        self.process_returns(samples)
    loss_inputs = LossInputs(  # So can slice all.
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        return_=return_,
        old_value=value,
        advantage=advantage,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
    )
    if recurrent:
        # Leave in [B,N,H] for slicing to minibatches.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
    T, B = samples.env.reward.shape[:2]
    # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
    batch_size = B if recurrent else T * B
    policy_gradients = []
    value_gradients = []
    all_value_diffs = []
    all_ratios = []
    for idxs in iterate_mb_idxs(batch_size, batch_size, shuffle=True):
        T_idxs = slice(None) if recurrent else idxs % T
        B_idxs = idxs if recurrent else idxs // T
        self.optimizer.zero_grad()
        rnn_state = init_rnn_state[B_idxs] if recurrent else None
        # NOTE: if not recurrent, will lose leading T dim; should be OK.
        loss, pi_loss, value_loss, entropy, perplexity, value_diffs, ratio = \
            self.loss(*loss_inputs[T_idxs, B_idxs], rnn_state)
        loss.backward()
        # Parameter order: the first 7 tensors belong to the policy
        # network, the last 6 to the value network.
        params = [p.grad.data.cpu().numpy().flatten()
            for p in self.agent.parameters()]
        policy_gradients.append(np.concatenate(params[:7]).ravel())
        value_gradients.append(np.concatenate(params[7:]).ravel())
        all_value_diffs.extend(value_diffs.detach().cpu().numpy().flatten())
        all_ratios.extend(ratio.detach().cpu().numpy().flatten())
        self.update_counter += 1
    return (policy_gradients, value_gradients, all_value_diffs,
        all_ratios, reward, pre_reward)
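# One way the collected gradients might be analyzed: cosine similarity
# between successive policy gradients.  Illustrative usage only; `algo`
# is a hypothetical instance, and this analysis is not part of the code
# above:
import numpy as np

def cosine_similarity(g1, g2, eps=1e-12):
    return float(np.dot(g1, g2)
        / (np.linalg.norm(g1) * np.linalg.norm(g2) + eps))

# pgs, vgs, *_ = algo.compute_minibatch_gradients(samples)  # Hypothetical.
# sims = [cosine_similarity(a, b) for a, b in zip(pgs[:-1], pgs[1:])]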