def optimize_agent(self, itr, samples): """ Train the agent on input samples, by one gradient step. """ if hasattr(self.agent, "update_obs_rms"): # NOTE: suboptimal--obs sent to device here and in agent(*inputs). self.agent.update_obs_rms(samples.env.observation) self.optimizer.zero_grad() loss, pi_loss, value_loss, entropy, perplexity = self.loss(samples) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.parameters(), self.clip_grad_norm) self.optimizer.step() opt_info = OptInfo( loss=loss.item(), pi_loss=pi_loss.item(), value_loss=value_loss.item(), gradNorm=grad_norm.clone().detach().item( ), # backwards compatible, entropy=entropy.item(), perplexity=perplexity.item(), ) self.update_counter += 1 if self.linear_lr_schedule: self.lr_scheduler.step() return opt_info
def optimize_agent(self, itr, samples): """ Train the agent, for multiple epochs over minibatches taken from the input samples. Organizes agent inputs from the training data, and moves them to device (e.g. GPU) up front, so that minibatches are formed within device, without further data transfer. """ recurrent = self.agent.recurrent agent_inputs = AgentInputs( # Move inputs to device once, index there. observation=samples.env.observation, prev_action=samples.agent.prev_action, prev_reward=samples.env.prev_reward, ) agent_inputs = buffer_to(agent_inputs, device=self.agent.device) if hasattr(self.agent, "update_obs_rms"): self.agent.update_obs_rms(agent_inputs.observation) return_, advantage, valid = self.process_returns(samples, self.normalize_rewards) loss_inputs = LossInputs( # So can slice all. agent_inputs=agent_inputs, action=samples.agent.action, return_=return_, advantage=advantage, valid=valid, old_dist_info=samples.agent.agent_info.dist_info, ) if recurrent: # Leave in [B,N,H] for slicing to minibatches. init_rnn_state = samples.agent.agent_info.prev_rnn_state[0] # T=0. T, B = samples.env.reward.shape[:2] opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) # If recurrent, use whole trajectories, only shuffle B; else shuffle all. batch_size = B if self.agent.recurrent else T * B mb_size = batch_size // self.minibatches for _ in range(self.epochs): for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True): T_idxs = slice(None) if recurrent else idxs % T B_idxs = idxs if recurrent else idxs // T self.optimizer.zero_grad() rnn_state = init_rnn_state[B_idxs] if recurrent else None # NOTE: if not recurrent, will lose leading T dim, should be OK. loss, entropy, perplexity = self.loss( *loss_inputs[T_idxs, B_idxs], rnn_state) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.parameters(), self.clip_grad_norm) self.optimizer.step() opt_info.loss.append(loss.item()) opt_info.gradNorm.append(grad_norm) opt_info.entropy.append(entropy.item()) opt_info.perplexity.append(perplexity.item()) self.update_counter += 1 if self.linear_lr_schedule: self.lr_scheduler.step() self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr # if self.vae_lr_scheduler: # self.vae_lr_scheduler.step() return opt_info
def optimize_agent(self, itr, samples):
    """
    Train the agent for multiple epochs over minibatches taken from the
    input samples, using separate optimizers for the policy loss and the
    value loss.
    """
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    return_, advantage, valid = self.process_returns(samples)
    loss_inputs = LossInputs(  # So can slice all.
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        return_=return_,
        advantage=advantage,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
    )
    if recurrent:
        # Leave in [B,N,H] for slicing to minibatches.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T=0.
    T, B = samples.env.reward.shape[:2]
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    # If recurrent, use whole trajectories, only shuffle B; else shuffle all.
    batch_size = B if self.agent.recurrent else T * B
    mb_size = batch_size // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            # NOTE: if not recurrent, will lose leading T dim, should be OK.
            pi_loss, value_loss, entropy, perplexity = self.loss(
                *loss_inputs[T_idxs, B_idxs], rnn_state)
            # Policy step.
            self.optimizer.zero_grad()
            pi_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()
            # Value step, with its own optimizer.
            self.v_optimizer.zero_grad()
            value_loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.v_optimizer.step()
            opt_info.loss.append(pi_loss.item())
            opt_info.gradNorm.append(grad_norm)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
            self.update_counter += 1
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
    self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    return opt_info

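# In the non-recurrent branch above, a flat index over the [T, B] batch is split
# into time and batch indices with ``idxs % T`` and ``idxs // T``, i.e. flat
# index = b * T + t.  A tiny self-contained check of that mapping (illustrative
# only, with example sizes):
import numpy as np

T, B = 4, 3
flat = np.arange(T * B)             # One flat index per (t, b) pair.
t_idxs, b_idxs = flat % T, flat // T
grid = np.zeros((T, B), dtype=int)  # Count how often each (t, b) cell is hit.
grid[t_idxs, b_idxs] += 1
assert (grid == 1).all()            # Every sample is visited exactly once per epoch.
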
def optimize_agent(self, itr, samples):
    """
    Train the agent on input samples, by one gradient step.
    """
    self.optimizer.zero_grad()
    loss, entropy, perplexity = self.loss(samples)
    loss.backward()
    grad_norm = torch.nn.utils.clip_grad_norm_(
        self.agent.parameters(), self.clip_grad_norm)
    self.optimizer.step()
    opt_info = OptInfo(
        loss=loss.item(),
        gradNorm=grad_norm,
        entropy=entropy.item(),
        perplexity=perplexity.item(),
    )
    return opt_info

def optimize_agent(self, itr, samples):
    """
    Train the agent for multiple epochs over minibatches taken from the
    input samples.
    """
    recurrent = self.agent.recurrent
    agent_inputs = AgentInputs(
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    return_, advantage, valid = self.process_returns(samples)
    loss_inputs = LossInputs(
        agent_inputs=agent_inputs,
        action=samples.agent.action,
        return_=return_,
        advantage=advantage,
        valid=valid,
        old_dist_info=samples.agent.agent_info.dist_info,
    )
    if recurrent:
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]
    T, B = samples.env.reward.shape[:2]
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    batch_size = B if self.agent.recurrent else T * B
    mb_size = batch_size // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True):
            T_idxs = slice(None) if recurrent else idxs % T
            B_idxs = idxs if recurrent else idxs // T
            self.optimizer.zero_grad()
            rnn_state = init_rnn_state[B_idxs] if recurrent else None
            loss, entropy, perplexity = self.loss(
                *loss_inputs[T_idxs, B_idxs], rnn_state)
            loss.backward()
            grad_norm = torch.nn.utils.clip_grad_norm_(
                self.agent.parameters(), self.clip_grad_norm)
            self.optimizer.step()
            opt_info.loss.append(loss.item())
            opt_info.gradNorm.append(grad_norm)
            opt_info.entropy.append(entropy.item())
            opt_info.perplexity.append(perplexity.item())
    if self.linear_lr_schedule:
        self.lr_scheduler.step()
    self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr
    return opt_info

def optimize_agent(self, itr, samples): """ Train the agent on input samples, by one gradient step. """ self.optimizer.zero_grad() loss, entropy, perplexity = self.loss(samples) loss.backward() grad_norm = torch.nn.utils.clip_grad_norm_(self.agent.parameters(), self.clip_grad_norm) self.optimizer.step() opt_info = OptInfo( loss=loss.item(), gradNorm=grad_norm, entropy=entropy.item(), perplexity=perplexity.item(), ) self.update_counter += 1 return opt_info
def optimize_agent(self, itr, samples=None, sampler_itr=None):
    """
    Train the agent, for multiple epochs over minibatches taken from the
    input samples.  Organizes agent inputs from the training data, and
    moves them to device (e.g. GPU) up front, so that minibatches are
    formed within device, without further data transfer.
    """
    opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields))))
    agent_inputs = AgentInputs(  # Move inputs to device once, index there.
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    agent_inputs = buffer_to(agent_inputs, device=self.agent.device)
    init_rnn_states = buffer_to(samples.agent.agent_info.prev_rnn_state[0],
        device=self.agent.device)
    T, B = samples.env.reward.shape[:2]
    mb_size = B // self.minibatches
    for _ in range(self.epochs):
        for idxs in iterate_mb_idxs(B, mb_size, shuffle=True):
            self.optimizer.zero_grad()
            init_rnn_state = buffer_method(init_rnn_states[idxs],
                "transpose", 0, 1)
            dist_info, value, _ = self.agent(*agent_inputs[:, idxs],
                init_rnn_state)
            loss, opt_info = self.process_returns(
                samples.env.reward[:, idxs],
                done=samples.env.done[:, idxs],
                value_prediction=value.cpu(),
                action=samples.agent.action[:, idxs],
                dist_info=dist_info,
                old_dist_info=samples.agent.agent_info.dist_info[:, idxs],
                opt_info=opt_info,
            )
            loss.backward()
            self.optimizer.step()
            self.clamp_lagrange_multipliers()
            opt_info.loss.append(loss.item())
            self.update_counter += 1
    return opt_info

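# The initial RNN state above is stored batch-major ([B, N, H]) so it can be
# sliced by batch index, then transposed to [N, B, H] before the agent call,
# which is the layout cudnn-style RNNs expect.  A shape-only illustration with a
# plain tensor and example sizes (the real code goes through ``buffer_method``
# so the same transpose also applies to namedtuples of tensors, e.g. LSTM
# (h, c) pairs):
import torch

B, N, H = 8, 1, 256                        # Batch size, RNN layers, hidden size.
stored_state = torch.zeros(B, N, H)        # As saved in samples.agent.agent_info.
mb_state = stored_state[torch.arange(4)]   # Slice a minibatch along B -> [4, N, H].
rnn_input_state = mb_state.transpose(0, 1)  # -> [N, 4, H] for the RNN call.
assert rnn_input_state.shape == (N, 4, H)
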
def optimize_agent(self, itr, samples): """ Train the agent, for multiple epochs over minibatches taken from the input samples. Organizes agent inputs from the training data, and moves them to device (e.g. GPU) up front, so that minibatches are formed within device, without further data transfer. """ recurrent = self.agent.recurrent agent_inputs = AgentInputs( # Move inputs to device once, index there. observation=samples.env.observation, prev_action=samples.agent.prev_action, prev_reward=samples.env.prev_reward, ) agent_inputs = buffer_to(agent_inputs, device=self.agent.device) if hasattr(self.agent, "update_obs_rms"): self.agent.update_obs_rms(agent_inputs.observation) if self.agent.dual_model: return_, advantage, valid, return_int_, advantage_int = self.process_returns( samples) else: return_, advantage, valid = self.process_returns(samples) if self.curiosity_type in {'icm', 'micm', 'disagreement'}: agent_curiosity_inputs = IcmAgentCuriosityInputs( observation=samples.env.observation.clone(), next_observation=samples.env.next_observation.clone(), action=samples.agent.action.clone(), valid=valid) agent_curiosity_inputs = buffer_to(agent_curiosity_inputs, device=self.agent.device) elif self.curiosity_type == 'ndigo': agent_curiosity_inputs = NdigoAgentCuriosityInputs( observation=samples.env.observation.clone(), prev_actions=samples.agent.prev_action.clone(), actions=samples.agent.action.clone(), valid=valid) agent_curiosity_inputs = buffer_to(agent_curiosity_inputs, device=self.agent.device) elif self.curiosity_type == 'rnd': agent_curiosity_inputs = RndAgentCuriosityInputs( next_observation=samples.env.next_observation.clone(), valid=valid) agent_curiosity_inputs = buffer_to(agent_curiosity_inputs, device=self.agent.device) elif self.curiosity_type == 'none': agent_curiosity_inputs = None if self.policy_loss_type == 'dual': loss_inputs = LossInputsTwin( # So can slice all. agent_inputs=agent_inputs, agent_curiosity_inputs=agent_curiosity_inputs, action=samples.agent.action, return_=return_, advantage=advantage, valid=valid, old_dist_info=samples.agent.agent_info.dist_info, return_int_=return_int_, advantage_int=advantage_int, old_dist_int_info=samples.agent.agent_info.dist_int_info, ) else: loss_inputs = LossInputs( # So can slice all. agent_inputs=agent_inputs, agent_curiosity_inputs=agent_curiosity_inputs, action=samples.agent.action, return_=return_, advantage=advantage, valid=valid, old_dist_info=samples.agent.agent_info.dist_info, ) if recurrent: # Leave in [B,N,H] for slicing to minibatches. init_rnn_state = samples.agent.agent_info.prev_rnn_state[0] # T=0. if self.agent.dual_model: init_int_rnn_state = samples.agent.agent_info.prev_int_rnn_state[ 0] # T=0. T, B = samples.env.reward.shape[:2] if self.policy_loss_type == 'dual': opt_info = OptInfoTwin(*([] for _ in range(len(OptInfoTwin._fields)))) else: opt_info = OptInfo(*([] for _ in range(len(OptInfo._fields)))) # If recurrent, use whole trajectories, only shuffle B; else shuffle all. batch_size = B if self.agent.recurrent else T * B mb_size = batch_size // self.minibatches for _ in range(self.epochs): for idxs in iterate_mb_idxs(batch_size, mb_size, shuffle=True): T_idxs = slice(None) if recurrent else idxs % T B_idxs = idxs if recurrent else idxs // T self.optimizer.zero_grad() rnn_state = init_rnn_state[B_idxs] if recurrent else None # NOTE: if not recurrent, will lose leading T dim, should be OK. 
if self.policy_loss_type == 'dual': int_rnn_state = init_int_rnn_state[ B_idxs] if recurrent else None loss_inputs_batch = loss_inputs[T_idxs, B_idxs] loss, pi_loss, value_loss, entropy_loss, entropy, perplexity, \ int_pi_loss, int_value_loss, int_entropy_loss, int_entropy, int_perplexity, \ curiosity_losses = self.loss( agent_inputs=loss_inputs_batch.agent_inputs, agent_curiosity_inputs=loss_inputs_batch.agent_curiosity_inputs, action=loss_inputs_batch.action, return_=loss_inputs_batch.return_, advantage=loss_inputs_batch.advantage, valid=loss_inputs_batch.valid, old_dist_info=loss_inputs_batch.old_dist_info, return_int_=loss_inputs_batch.return_int_, advantage_int=loss_inputs_batch.advantage_int, old_dist_int_info=loss_inputs_batch.old_dist_int_info, init_rnn_state=rnn_state, init_int_rnn_state=int_rnn_state) else: loss, pi_loss, value_loss, entropy_loss, entropy, perplexity, curiosity_losses = self.loss( *loss_inputs[T_idxs, B_idxs], rnn_state) loss.backward() count = 0 grad_norm = torch.nn.utils.clip_grad_norm_( self.agent.parameters(), self.clip_grad_norm) self.optimizer.step() # Tensorboard summaries opt_info.loss.append(loss.item()) opt_info.pi_loss.append(pi_loss.item()) opt_info.value_loss.append(value_loss.item()) opt_info.entropy_loss.append(entropy_loss.item()) if self.policy_loss_type == 'dual': opt_info.int_pi_loss.append(int_pi_loss.item()) opt_info.int_value_loss.append(int_value_loss.item()) opt_info.int_entropy_loss.append(int_entropy_loss.item()) if self.curiosity_type in {'icm', 'micm'}: inv_loss, forward_loss = curiosity_losses opt_info.inv_loss.append(inv_loss.item()) opt_info.forward_loss.append(forward_loss.item()) opt_info.intrinsic_rewards.append( np.mean(self.intrinsic_rewards)) opt_info.extint_ratio.append(np.mean(self.extint_ratio)) elif self.curiosity_type == 'disagreement': forward_loss = curiosity_losses opt_info.forward_loss.append(forward_loss.item()) opt_info.intrinsic_rewards.append( np.mean(self.intrinsic_rewards)) opt_info.extint_ratio.append(np.mean(self.extint_ratio)) elif self.curiosity_type == 'ndigo': forward_loss = curiosity_losses opt_info.forward_loss.append(forward_loss.item()) opt_info.intrinsic_rewards.append( np.mean(self.intrinsic_rewards)) opt_info.extint_ratio.append(np.mean(self.extint_ratio)) elif self.curiosity_type == 'rnd': forward_loss = curiosity_losses opt_info.forward_loss.append(forward_loss.item()) opt_info.intrinsic_rewards.append( np.mean(self.intrinsic_rewards)) opt_info.extint_ratio.append(np.mean(self.extint_ratio)) if self.normalize_reward: opt_info.reward_total_std.append(self.reward_rms.var**0.5) if self.policy_loss_type == 'dual': opt_info.int_reward_total_std.append( self.int_reward_rms.var**0.5) opt_info.entropy.append(entropy.item()) opt_info.perplexity.append(perplexity.item()) if self.policy_loss_type == 'dual': opt_info.int_entropy.append(int_entropy.item()) opt_info.int_perplexity.append(int_perplexity.item()) self.update_counter += 1 opt_info.return_.append( torch.mean(return_.detach()).detach().clone().item()) opt_info.advantage.append( torch.mean(advantage.detach()).detach().clone().item()) opt_info.valpred.append( torch.mean(samples.agent.agent_info.value.detach()).detach().clone( ).item()) if self.policy_loss_type == 'dual': opt_info.return_int_.append( torch.mean(return_int_.detach()).detach().clone().item()) opt_info.advantage_int.append( torch.mean(advantage_int.detach()).detach().clone().item()) opt_info.int_valpred.append( torch.mean(samples.agent.agent_info.int_value.detach()).detach( 
).clone().item()) if self.linear_lr_schedule: self.lr_scheduler.step() self.ratio_clip = self._ratio_clip * (self.n_itr - itr) / self.n_itr layer_info = dict( ) # empty dict to store model layer weights for tensorboard visualizations return opt_info, layer_info
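# The ``linear_lr_schedule`` branch above steps an LR scheduler once per call
# (i.e. once per iteration) and, in the same spirit, anneals the PPO ratio clip
# linearly toward zero over ``n_itr`` iterations.  A sketch of how such a
# scheduler could be built with a stock PyTorch LambdaLR (illustrative; the
# helper name is hypothetical and the actual construction lives in the
# algorithm's initialization code):
import torch

def make_linear_lr_scheduler(optimizer, n_itr):
    # Learning rate decays linearly from its initial value toward zero over
    # n_itr scheduler steps, one step per training iteration.
    return torch.optim.lr_scheduler.LambdaLR(
        optimizer, lr_lambda=lambda itr: (n_itr - itr) / n_itr)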