def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info, init_rnn_state=None):
    """
    Compute the unclipped surrogate policy loss for a feedforward agent.

    Loss: -mean(likelihood_ratio * advantage), where the ratio is taken
    between the current policy output and ``old_dist_info``.

    Calls the agent for a forward pass on ``agent_inputs`` and uses
    ``agent.distribution`` for the likelihood ratio, entropy and perplexity.
    ``return_`` is accepted but not used by this loss.

    Returns:
        Tuple ``(loss, entropy, perplexity)``.

    Raises:
        NotImplementedError: if ``init_rnn_state`` is given (recurrent agents
            are not supported).
    """
    if init_rnn_state is not None:
        # Recurrent training is not implemented; fail before spending a
        # forward pass whose outputs would be discarded anyway.
        raise NotImplementedError
    dist_info, _value = self.agent(*agent_inputs)
    dist = self.agent.distribution
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    # Pass ``valid`` so the surrogate is reduced over the same entries as the
    # entropy and perplexity below (presumably a no-op when valid is None --
    # confirm against valid_mean).
    loss = -valid_mean(ratio * advantage, valid)
    entropy = dist.mean_entropy(dist_info, valid)
    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info, init_rnn_state=None):
    """
    Clipped-surrogate training loss: policy_loss + value_loss + entropy_loss.

    Policy loss: -mean(min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv))
    Value loss:  value_loss_coeff * mean(0.5 * (value - return) ^ 2)
    Entropy loss: -entropy_loss_coeff * mean entropy

    Runs the agent forward (with the initial recurrent state when one is
    given) and returns ``(loss, entropy, perplexity)``.
    """
    if init_rnn_state is None:
        dist_info, value = self.agent(*agent_inputs)
    else:
        # [B,N,H] --> [N,B,H] (for cudnn).
        rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        rnn_state = buffer_method(rnn_state, "contiguous")
        dist_info, value, _ = self.agent(*agent_inputs, rnn_state)
    distribution = self.agent.distribution

    # Policy term.
    likelihood_ratio = distribution.likelihood_ratio(
        action, old_dist_info=old_dist_info, new_dist_info=dist_info)
    unclipped = likelihood_ratio * advantage
    clipped = torch.clamp(
        likelihood_ratio, 1. - self.ratio_clip, 1. + self.ratio_clip
    ) * advantage
    pi_loss = -valid_mean(torch.min(unclipped, clipped), valid)

    # Value term.
    squared_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(squared_error, valid)

    # Entropy bonus.
    entropy = distribution.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    total_loss = pi_loss + value_loss + entropy_loss
    perplexity = distribution.mean_perplexity(dist_info, valid)
    return total_loss, entropy, perplexity
def loss(self, samples):
    """
    Compute the training loss: policy_loss + value_loss + entropy_loss.

    Policy loss: -mean(log-likelihood of actions * advantage)
    Value loss:  value_loss_coeff * mean(0.5 * (value - return) ^ 2)
    Entropy loss: -entropy_loss_coeff * mean entropy

    Builds the agent inputs from ``samples``, runs the agent forward (with
    the stored recurrent state at T = 0 for recurrent agents), and obtains
    returns/advantages/valid mask from ``self.process_returns``.

    Returns:
        Tuple ``(loss, entropy, perplexity)``.
    """
    agent_inputs = AgentInputs(
        observation=samples.env.observation,
        prev_action=samples.agent.prev_action,
        prev_reward=samples.env.prev_reward,
    )
    if self.agent.recurrent:
        # Read the state from the ``samples`` argument (not ``self.samples``)
        # so the loss is computed on the batch that was actually passed in.
        init_rnn_state = samples.agent.agent_info.prev_rnn_state[0]  # T = 0.
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)
    # TODO: try to compute everything on device.
    return_, advantage, valid = self.process_returns(samples)

    dist = self.agent.distribution
    logli = dist.log_likelihood(samples.agent.action, dist_info)
    pi_loss = -valid_mean(logli * advantage, valid)

    value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss

    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info, old_value,
        init_rnn_state=None):
    """
    Compute the training loss: policy_loss + value_loss + entropy_loss

    Policy loss: min(likelihood-ratio * advantage,
                     clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
    Value loss:  0.5 * (estimated_value - return) ^ 2; when
                 ``self.clip_vf_loss`` is set, the element-wise maximum of
                 that squared error and the squared error of the value
                 clipped to within ``ratio_clip`` of ``old_value``.

    Calls the agent to compute forward pass on training data (passing
    ``device=action.device``), and uses the ``agent.distribution`` to compute
    likelihoods and entropies.  Valid for feedforward or recurrent agents.

    Returns:
        Tuple ``(loss, pi_loss, value_loss, entropy, perplexity)``.
    """
    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs, init_rnn_state,
            device=action.device)
    else:
        dist_info, value = self.agent(*agent_inputs, device=action.device)
    dist = self.agent.distribution

    # Surrogate policy loss
    ratio = dist.likelihood_ratio(action, old_dist_info=old_dist_info,
        new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = -valid_mean(surrogate, valid)

    # Surrogate value loss (if doing)
    if self.clip_vf_loss:
        v_loss_unclipped = (value - return_) ** 2
        v_clipped = old_value + torch.clamp(
            value - old_value, -self.ratio_clip, self.ratio_clip)
        v_loss_clipped = (v_clipped - return_) ** 2
        v_loss_max = torch.max(v_loss_unclipped, v_loss_clipped)
        # Keep the error element-wise: reducing here with .mean() would make
        # the valid_mean below unable to apply the ``valid`` mask.
        value_error = 0.5 * v_loss_max
    else:
        value_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss

    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, pi_loss, value_loss, entropy, perplexity
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info, init_rnn_state=None):
    """
    Clipped-surrogate loss with an optional similarity regularizer.

    Base loss: policy_loss + value_loss + entropy_loss, where
      policy loss = -mean(min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv))
      value loss  = value_loss_coeff * mean(0.5 * (value - return) ^ 2)

    When ``self.similarity_loss`` is set, adds
      -similarity_coeff * mean(cosine_similarity(prob, og_prob))
      + 0.5 * mean((value - og_value) ^ 2).

    NOTE(review): ``og_dist_info`` / ``og_value`` are only bound on the
    feedforward path with ``self.agent.both_actions`` set; enabling
    ``similarity_loss`` on any other path fails with an unbound name --
    confirm the configuration guarantees this.

    Returns ``(loss, entropy, perplexity)``.
    """
    if init_rnn_state is None:
        if self.agent.both_actions:
            dist_info, og_dist_info, value, og_value = self.agent(*agent_inputs)
        else:
            dist_info, value = self.agent(*agent_inputs)
    else:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _ = self.agent(*agent_inputs, init_rnn_state)

    distribution = self.agent.distribution
    likelihood_ratio = distribution.likelihood_ratio(
        action, old_dist_info=old_dist_info, new_dist_info=dist_info)
    clipped = torch.clamp(
        likelihood_ratio, 1. - self.ratio_clip, 1. + self.ratio_clip)
    objective = torch.min(likelihood_ratio * advantage, clipped * advantage)
    pi_loss = -valid_mean(objective, valid)

    value_loss = self.value_loss_coeff * valid_mean(
        0.5 * (value - return_) ** 2, valid)

    entropy = distribution.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss

    if self.similarity_loss:
        # Idea noted by the original author: try a KL term here next, i.e.
        #   pi_sim = self.agent.distribution.kl(og_dist_info, dist_info)
        pi_sim = F.cosine_similarity(dist_info.prob, og_dist_info.prob)
        value_sim = (value - og_value) ** 2
        loss += -self.similarity_coeff * pi_sim.mean() + 0.5 * value_sim.mean()
        # Alternative weighting considered by the original author:
        #   loss += self.similarity_coeff * (pi_sim.mean() + value_sim)

    perplexity = distribution.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def loss(
    self,
    agent_inputs,
    action,
    return_,
    advantage,
    valid,
    old_dist_info,
    init_rnn_state=None,
):
    """
    Training loss: policy_loss + value_loss + entropy_loss.

    Policy loss: -mean(min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv)),
                 with the likelihood ratio first capped at 1000 so it can
                 never be infinite.
    Value loss:  value_loss_coeff * mean(0.5 * (value - return) ^ 2)

    The agent is called for the forward pass and ``agent.distribution``
    supplies likelihood ratio, entropy and perplexity.  Works with or
    without an initial recurrent state.  Returns
    ``(loss, entropy, perplexity)``.
    """
    if init_rnn_state is None:
        dist_info, value = self.agent(*agent_inputs)
    else:
        # [B,N,H] --> [N,B,H] (for cudnn).
        rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        rnn_state = buffer_method(rnn_state, "contiguous")
        dist_info, value, _ = self.agent(*agent_inputs, rnn_state)
    distribution = self.agent.distribution

    likelihood_ratio = distribution.likelihood_ratio(
        action, old_dist_info=old_dist_info, new_dist_info=dist_info
    )
    # Added upstream to prevent ratio == inf.
    likelihood_ratio = likelihood_ratio.clamp_max(1000)
    plain_term = likelihood_ratio * advantage
    clipped_term = advantage * torch.clamp(
        likelihood_ratio, 1.0 - self.ratio_clip, 1.0 + self.ratio_clip
    )
    pi_loss = -valid_mean(torch.min(plain_term, clipped_term), valid)

    squared_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(squared_error, valid)

    entropy = distribution.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    total_loss = pi_loss + value_loss + entropy_loss
    perplexity = distribution.mean_perplexity(dist_info, valid)
    return total_loss, entropy, perplexity
def step(self, observation, prev_action, prev_reward):
    """
    Run the recurrent model for one step and choose an action.

    The previous action is one-hot encoded, inputs are moved to
    ``self.device``, and the model returns action probabilities, a value
    estimate and the next recurrent state.  In ``'sample'`` mode the action
    is drawn from the distribution; in ``'eval'`` mode it is the argmax of
    the probabilities.  The action and agent info are returned on CPU and
    the agent's recurrent state is advanced.
    """
    onehot_prev_action = self.distribution.to_onehot(prev_action)
    model_inputs = buffer_to(
        (observation, onehot_prev_action, prev_reward), device=self.device)
    probs, value, next_rnn_state = self.model(*model_inputs, self.prev_rnn_state)
    dist_info = DistInfo(prob=probs)

    if self._mode == 'sample':
        action = self.distribution.sample(dist_info)
    elif self._mode == 'eval':
        action = torch.argmax(probs, dim=-1)

    # Model handles None, but Buffer does not, make zeros if needed:
    stored_state = (
        buffer_func(next_rnn_state, torch.zeros_like)
        if self.prev_rnn_state is None
        else self.prev_rnn_state
    )
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    stored_state = buffer_method(stored_state, "transpose", 0, 1)

    agent_info = AgentInfoRnn(
        dist_info=dist_info, value=value, prev_rnn_state=stored_state)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    self.advance_rnn_state(next_rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
def __init__(self, *args, **kwargs):
    """
    Initialize the parent class, then allocate per-environment bookkeeping.

    Creates boolean ``need_reset`` and ``done`` flags (one entry per
    environment in ``self.envs``, all False) and a copy of the first
    observation row of ``self.samples_np`` for those environments.
    """
    super().__init__(*args, **kwargs)
    n_envs = len(self.envs)
    # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24; the resulting dtype is the same.
    self.need_reset = np.zeros(n_envs, dtype=bool)
    # Done flag for every environment, initialized to "not done".
    self.done = np.zeros(n_envs, dtype=bool)
    self.temp_observation = buffer_method(
        self.samples_np.env.observation[0, :n_envs], "copy")
def optimize_agent(self, itr, samples=None, sampler_itr=None):
    """
    Train the agent on batches drawn from the replay buffer.

    New ``samples`` (when given) are converted with ``samples_to_buffer``
    and appended to the replay buffer first.  Then, for every batch yielded
    by the buffer's ``batch_generator`` (``replay_ratio=self.epochs``), one
    optimizer step is taken on the loss from ``process_returns`` and the
    Lagrange multipliers are clamped.  Loss values and buffer wait times are
    accumulated in the returned ``OptInfo``.
    """
    if samples is not None:
        buffered = self.samples_to_buffer(samples)
        self.replay_buffer.append_samples(buffered)

    # One empty list per OptInfo field.
    opt_info = OptInfo(*[[] for _ in OptInfo._fields])

    batches = self.replay_buffer.batch_generator(replay_ratio=self.epochs)
    for batch, stored_rnn_state, wait_time in batches:
        self.optimizer.zero_grad()
        rnn_state = buffer_method(stored_rnn_state, "transpose", 0, 1)
        dist_info, value, _ = self.agent(*batch.agent_inputs, rnn_state)
        loss, opt_info = self.process_returns(
            reward=batch.reward,
            done=batch.done,
            value_prediction=value,
            action=batch.action,
            dist_info=dist_info,
            old_dist_info=batch.dist_info,
            opt_info=opt_info,
        )
        loss.backward()
        self.optimizer.step()
        self.clamp_lagrange_multipliers()

        opt_info.loss.append(loss.item())
        opt_info.optim_buffer_wait_time.append(wait_time)
        self.update_counter += 1
    return opt_info
def __init__(self, *args, **kwargs):
    """
    Initialize the parent class, then allocate per-environment bookkeeping.

    Creates a boolean ``need_reset`` flag per environment in ``self.envs``
    (all False) and a copy of the step buffer's observation.
    """
    super().__init__(*args, **kwargs)
    # Use the builtin ``bool``: the ``np.bool`` alias was deprecated in
    # NumPy 1.20 and removed in 1.24; the resulting dtype is the same.
    self.need_reset = np.zeros(len(self.envs), dtype=bool)
    # e.g. For episodic lives, hold the observation output when done, record
    # blanks for the rest of the batch, but reinstate the observation to start
    # next batch.
    self.temp_observation = buffer_method(self.step_buffer_np.observation, "copy")
def step(self, observation, prev_action, prev_reward, device="cpu"):
    """
    Produce one action from the recurrent Gaussian policy.

    Inputs are moved to ``self.device``; the model returns mean, log_std, a
    value estimate and the next recurrent state.  An action is sampled from
    the resulting distribution, and the action plus agent info are moved to
    ``device`` (CPU by default) for the sampler.  The agent's recurrent
    state is advanced.  (Presumably executed without gradients by a wrapper
    not visible here -- confirm.)
    """
    model_inputs = buffer_to(
        (observation, prev_action, prev_reward), device=self.device)
    mean, log_std, value, next_rnn_state = self.model(
        *model_inputs, self.prev_rnn_state)
    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    action = self.distribution.sample(dist_info)

    # Model handles None, but Buffer does not, make zeros if needed:
    if self.prev_rnn_state is None:
        stored_state = buffer_func(next_rnn_state, torch.zeros_like)
    else:
        stored_state = self.prev_rnn_state
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    stored_state = buffer_method(stored_state, "transpose", 0, 1)

    agent_info = AgentInfoRnn(
        dist_info=dist_info, value=value, prev_rnn_state=stored_state)
    action, agent_info = buffer_to((action, agent_info), device=device)
    self.advance_rnn_state(next_rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    """
    Run the recurrent model(s) for one step and sample an action.

    The previous action is one-hot encoded and inputs are moved to
    ``self.device``.  With ``self.dual_model`` set, a second model
    (``self.model_int``) is also run on the same inputs; the action is then
    sampled from the first model's distribution in ``"eval"`` mode and from
    the second model's distribution otherwise.  Without ``dual_model`` the
    action always comes from the first model.

    Returns:
        ``AgentStep`` with the action and agent info on CPU.  The agent info
        is ``AgentInfoRnnTwin`` when ``dual_model`` is set, else
        ``AgentInfoRnn``.  The recurrent state(s) are advanced.
    """
    prev_action = self.distribution.to_onehot(prev_action)
    agent_inputs = buffer_to((observation, prev_action, prev_reward),
        device=self.device)
    pi, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
    dist_info = DistInfo(prob=pi)
    if self.dual_model:
        int_pi, int_value, int_rnn_state = self.model_int(
            *agent_inputs, self.prev_int_rnn_state)
        dist_int_info = DistInfo(prob=int_pi)
        if self._mode == "eval":
            action = self.distribution.sample(dist_info)
        else:
            action = self.distribution.sample(dist_int_info)
    else:
        action = self.distribution.sample(dist_info)

    # Model handles None, but Buffer does not, make zeros if needed.
    # Test for None explicitly: ``state or zeros`` relies on truthiness,
    # which is ambiguous (raises) for a multi-element tensor state.
    prev_rnn_state = self.prev_rnn_state
    if prev_rnn_state is None:
        prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)

    if self.dual_model:
        prev_int_rnn_state = self.prev_int_rnn_state
        if prev_int_rnn_state is None:
            prev_int_rnn_state = buffer_func(int_rnn_state, torch.zeros_like)
        prev_int_rnn_state = buffer_method(prev_int_rnn_state, "transpose", 0, 1)
        agent_info = AgentInfoRnnTwin(
            dist_info=dist_info,
            value=value,
            prev_rnn_state=prev_rnn_state,
            dist_int_info=dist_int_info,
            int_value=int_value,
            prev_int_rnn_state=prev_int_rnn_state)
    else:
        agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
            prev_rnn_state=prev_rnn_state)

    action, agent_info = buffer_to((action, agent_info), device="cpu")
    self.advance_rnn_state(rnn_state)  # Keep on device.
    if self.dual_model:
        self.advance_int_rnn_state(int_rnn_state)
    return AgentStep(action=action, agent_info=agent_info)
def loss(self, samples):
    """
    Training loss: policy_loss + value_loss + entropy_loss.

    Policy loss: -mean(log-likelihood of actions * advantage)
    Value loss:  value_loss_coeff * mean(0.5 * (value - return) ^ 2)
    Entropy loss: -entropy_loss_coeff * mean entropy

    Agent inputs are assembled from ``samples`` and the agent is called on
    the device of the previous-action input.  Recurrent agents are started
    from the stored recurrent state at T = 0.  Returns, advantages and the
    valid mask come from ``self.process_returns``.

    Returns ``(loss, pi_loss, value_loss, entropy, perplexity)``.
    """
    env_samples, agent_samples = samples.env, samples.agent
    agent_inputs = AgentInputs(
        observation=env_samples.observation,
        prev_action=agent_samples.prev_action,
        prev_reward=env_samples.prev_reward,
    )
    if not self.agent.recurrent:
        dist_info, value = self.agent(
            *agent_inputs, device=agent_inputs.prev_action.device)
    else:
        rnn_state = agent_samples.agent_info.prev_rnn_state[0]  # T = 0.
        # [B,N,H] --> [N,B,H] (for cudnn).
        rnn_state = buffer_method(rnn_state, "transpose", 0, 1)
        rnn_state = buffer_method(rnn_state, "contiguous")
        dist_info, value, _ = self.agent(
            *agent_inputs, rnn_state,
            device=agent_inputs.prev_action.device)

    # TODO: try to compute everything on device.
    return_, advantage, valid = self.process_returns(samples)

    distribution = self.agent.distribution
    log_likelihood = distribution.log_likelihood(agent_samples.action, dist_info)
    pi_loss = -valid_mean(log_likelihood * advantage, valid)

    squared_error = 0.5 * (value - return_) ** 2
    value_loss = self.value_loss_coeff * valid_mean(squared_error, valid)

    entropy = distribution.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    total_loss = pi_loss + value_loss + entropy_loss
    perplexity = distribution.mean_perplexity(dist_info, valid)
    return total_loss, pi_loss, value_loss, entropy, perplexity
def beta_kl_losses(
    self,
    agent_inputs,
    action,
    return_,
    advantage,
    valid,
    old_dist_info,
    c_return,
    c_advantage,
    init_rnn_state=None,
):
    """
    Compute two clipped-surrogate losses from the agent's beta dist infos.

    ``self.agent.beta_dist_infos`` returns two distribution infos (named
    ``r`` and ``c`` here).  For the ``r`` one the loss is
    -mean(min(ratio * advantage, clipped_ratio * advantage)); for the ``c``
    one it is +mean(max(ratio * c_advantage, clipped_ratio * c_advantage)).
    Both ratios are taken against ``old_dist_info`` and clipped to
    [1 - ratio_clip, 1 + ratio_clip].  ``return_`` and ``c_return`` are not
    used.

    Returns ``(beta_r_loss, beta_c_loss)``.
    """
    if init_rnn_state is None:
        r_dist_info, c_dist_info = self.agent.beta_dist_infos(*agent_inputs)
    else:
        # [B,N,H] --> [N,B,H] (for cudnn).
        rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        rnn_state = buffer_method(rnn_state, "contiguous")
        r_dist_info, c_dist_info = self.agent.beta_dist_infos(
            *agent_inputs, rnn_state)
    distribution = self.agent.distribution

    # "r" head: minimize the negative pessimistic (min) surrogate.
    r_ratio = distribution.likelihood_ratio(
        action, old_dist_info=old_dist_info, new_dist_info=r_dist_info)
    r_clipped = torch.clamp(
        r_ratio, 1.0 - self.ratio_clip, 1.0 + self.ratio_clip)
    r_objective = torch.min(r_ratio * advantage, r_clipped * advantage)
    beta_r_loss = -valid_mean(r_objective, valid)

    # "c" head: minimize the max surrogate directly (note the sign).
    c_ratio = distribution.likelihood_ratio(
        action, old_dist_info=old_dist_info, new_dist_info=c_dist_info)
    c_clipped = torch.clamp(
        c_ratio, 1.0 - self.ratio_clip, 1.0 + self.ratio_clip)
    c_objective = torch.max(c_ratio * c_advantage, c_clipped * c_advantage)
    beta_c_loss = valid_mean(c_objective, valid)

    return beta_r_loss, beta_c_loss
def rollout_policy(self, steps: int, policy, prev_state: RSSMState):
    """
    Unroll the transition model for ``steps`` steps under ``policy``.

    At every step the policy is queried on a detached copy of the current
    state, and the transition model maps (action, state) to the next state.

    :param steps: number of steps to roll out
    :param policy: callable taking an RSSMState; the first element of its
        return value is used as the action
    :param prev_state: starting RSSM state (detached before use)
    :return: (next states stacked along dim 0, actions stacked along dim 0)
    """
    current = buffer_method(prev_state, 'detach')
    visited_states = []
    taken_actions = []
    for _ in range(steps):
        # The policy never receives gradients through the state.
        action, _extra = policy(buffer_method(current, 'detach'))
        current = self.transition_model(action, current)
        visited_states.append(current)
        taken_actions.append(action)
    stacked_states = stack_states(visited_states, dim=0)
    stacked_actions = torch.stack(taken_actions, dim=0)
    return stacked_states, stacked_actions
def to_agent_step(self, output):
    """
    Convert the output of the NN model into step info for the agent.

    ``output`` is unpacked as ``(q, rnn_state)``.  An action is sampled from
    ``q``; the previous recurrent state (zeros shaped like ``rnn_state`` if
    there is none yet), the action and ``q`` are moved to CPU, and the
    agent's recurrent state is advanced to ``rnn_state``.

    Returns:
        ``AgentStep`` carrying the action and an ``AgentInfo`` with ``q`` and
        the previous recurrent state.
    """
    q, rnn_state = output
    action = self.distribution.sample(q)
    # Test for None explicitly: ``state or zeros`` relies on truthiness,
    # which is ambiguous (raises) for a multi-element tensor state.
    prev_rnn_state = self.prev_rnn_state
    if prev_rnn_state is None:
        prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case, model should always leave B dimension in.)
    prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
    prev_rnn_state, action, q = buffer_to((prev_rnn_state, action, q),
        device="cpu")
    agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
    self.advance_rnn_state(rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
def loss(self, agent_inputs, action, return_, advantage, valid, old_dist_info, init_rnn_state=None):
    """
    Clipped-surrogate training loss: policy_loss + value_loss + entropy_loss.

    Policy loss: -mean(min(ratio * adv, clip(ratio, 1-eps, 1+eps) * adv))
    Value loss:  value_loss_coeff * mean(0.5 * (value - return) ^ 2)
    Entropy loss: -entropy_loss_coeff * mean entropy

    Returns ``(loss, entropy, perplexity)``.
    """
    recurrent = init_rnn_state is not None
    if recurrent:
        # [B,N,H] --> [N,B,H] (for cudnn).
        rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        rnn_state = buffer_method(rnn_state, "contiguous")
        dist_info, value, _ = self.agent(*agent_inputs, rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)

    # Multi-agent note from the original author: no reshape is done here
    # because the dist.* functions appear to handle multi-dimensional tensors
    # (entropy and likelihood ratios are computed along the last dim).
    # TODO: double check that this really holds.
    distribution = self.agent.distribution

    likelihood_ratio = distribution.likelihood_ratio(
        action, old_dist_info=old_dist_info, new_dist_info=dist_info)
    lower, upper = 1. - self.ratio_clip, 1. + self.ratio_clip
    unclipped_term = likelihood_ratio * advantage
    clipped_term = torch.clamp(likelihood_ratio, lower, upper) * advantage
    pi_loss = -valid_mean(torch.min(unclipped_term, clipped_term), valid)

    value_loss = self.value_loss_coeff * valid_mean(
        0.5 * (value - return_) ** 2, valid)

    entropy = distribution.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    total_loss = pi_loss + value_loss + entropy_loss
    perplexity = distribution.mean_perplexity(dist_info, valid)
    return total_loss, entropy, perplexity
def step(self, observation, prev_action, prev_reward):
    """
    Sample one action from the recurrent Gaussian policy.

    Inputs are moved to ``self.device``; the model returns mean, log_std, a
    value estimate and the next recurrent state.  The sampled action and the
    agent info (distribution info, value, previous recurrent state) are
    returned on CPU, and the agent's recurrent state is advanced.

    Returns:
        ``AgentStep`` with ``action`` and ``AgentInfoRnn`` agent info.
    """
    agent_inputs = buffer_to((observation, prev_action, prev_reward),
        device=self.device)
    mu, log_std, value, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)
    dist_info = DistInfoStd(mean=mu, log_std=log_std)
    action = self.distribution.sample(dist_info)
    # Model handles None, but Buffer does not, make zeros if needed.
    # Test for None explicitly: ``state or zeros`` relies on truthiness,
    # which is ambiguous (raises) for a multi-element tensor state.
    prev_rnn_state = self.prev_rnn_state
    if prev_rnn_state is None:
        prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
    agent_info = AgentInfoRnn(dist_info=dist_info, value=value,
        prev_rnn_state=prev_rnn_state)
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    self.advance_rnn_state(rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    """
    Compute Q-values for the given observations and pick actions.

    The previous action is one-hot encoded, inputs go to ``self.device``,
    and the model returns Q-values plus the next recurrent state.  Q-values
    are moved to CPU and the action is sampled from them by
    ``self.distribution`` (described upstream as epsilon-greedy; presumably
    run without gradients -- neither is visible here, confirm).  The agent's
    recurrent state is advanced.
    """
    onehot_prev_action = self.distribution.to_onehot(prev_action)
    model_inputs = buffer_to(
        (observation, onehot_prev_action, prev_reward), device=self.device)
    # The model accepts a None recurrent state.
    q_values, next_rnn_state = self.model(*model_inputs, self.prev_rnn_state)
    q_values = q_values.cpu()
    action = self.distribution.sample(q_values)

    if self.prev_rnn_state is None:
        stored_state = buffer_func(next_rnn_state, torch.zeros_like)
    else:
        stored_state = self.prev_rnn_state
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case, model should always leave B dimension in.)
    stored_state = buffer_method(stored_state, "transpose", 0, 1)
    stored_state = buffer_to(stored_state, device="cpu")

    agent_info = AgentInfo(q=q_values, prev_rnn_state=stored_state)
    self.advance_rnn_state(next_rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    """
    Compute Q-values for the given observations and pick actions.

    The previous action is one-hot encoded, inputs are moved to
    ``self.device``, and the model returns Q-values plus the next recurrent
    state.  Q-values are moved to CPU and the action is sampled from them by
    ``self.distribution``.  The previous recurrent state (zeros if there is
    none yet) is stored on CPU in the agent info, and the agent's recurrent
    state is advanced.

    Returns:
        ``AgentStep`` with ``action`` and ``AgentInfo(q, prev_rnn_state)``.
    """
    prev_action = self.distribution.to_onehot(prev_action)
    agent_inputs = buffer_to((observation, prev_action, prev_reward),
        device=self.device)
    q, rnn_state = self.model(*agent_inputs, self.prev_rnn_state)  # Model handles None.
    q = q.cpu()
    action = self.distribution.sample(q)
    # Test for None explicitly: ``state or zeros`` relies on truthiness,
    # which is ambiguous (raises) for a multi-element tensor state.
    prev_rnn_state = self.prev_rnn_state
    if prev_rnn_state is None:
        prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case, model should always leave B dimension in.)
    prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
    prev_rnn_state = buffer_to(prev_rnn_state, device="cpu")
    agent_info = AgentInfo(q=q, prev_rnn_state=prev_rnn_state)
    self.advance_rnn_state(rnn_state)  # Keep on device.
    return AgentStep(action=action, agent_info=agent_info)
def imagine_trajectories(self, _initial_states: RSSMState, batch_t: int, batch_b: int):
    """
    Imagine trajectories {s_tau, a_tau} starting from each state s_t.

    All but the last time step of ``_initial_states`` are flattened to
    ``(batch_t - 1) * batch_b`` rows without tracking gradients.  The
    agent's policy is then rolled out for ``self.horizon`` steps through the
    model's rollout module while ``self.world_modules`` are wrapped in
    ``FreezeParameters``.

    :return: the imagined states from ``rollout_policy`` (actions are
        discarded)
    """
    flat_rows = (batch_t - 1) * (batch_b)
    # No gradient for the input (initial) states.
    with torch.no_grad():
        start_states = buffer_method(
            _initial_states[:-1, :], 'reshape', flat_rows, -1)
        # Original author's example shape for the RSSM mean: (2450, 30).

    # Imagine trajectories with a finite horizon H.
    model = self.agent.model
    rollout_module = model.rollout
    actor = model.policy
    with FreezeParameters(self.world_modules):
        rollout_result = rollout_module.rollout_policy(
            self.horizon, actor, start_states)
        # Original author's example shape for the RSSM mean: (10, 2450, 30).
        imagined_states, _ = rollout_result
    return imagined_states
def optimize_agent(self, itr, samples=None, sampler_itr=None):
    """
    Train the agent for ``self.epochs`` passes over minibatches of ``samples``.

    Agent inputs and the initial recurrent states (time index 0) are moved
    to the agent's device once up front, so minibatches are indexed on
    device.  The batch dimension is split into ``self.minibatches`` shuffled
    groups; for each one the agent is run, the loss is obtained from
    ``process_returns`` (with the value prediction moved to CPU), one
    optimizer step is taken and the Lagrange multipliers are clamped.

    Returns the ``OptInfo`` with the per-update loss values appended.
    """
    # One empty list per OptInfo field.
    opt_info = OptInfo(*[[] for _ in OptInfo._fields])

    # Move inputs to device once, index there.
    agent_inputs = buffer_to(
        AgentInputs(
            observation=samples.env.observation,
            prev_action=samples.agent.prev_action,
            prev_reward=samples.env.prev_reward,
        ),
        device=self.agent.device,
    )
    init_rnn_states = buffer_to(
        samples.agent.agent_info.prev_rnn_state[0], device=self.agent.device)

    _T, batch_B = samples.env.reward.shape[:2]
    minibatch_size = batch_B // self.minibatches
    for _epoch in range(self.epochs):
        for idxs in iterate_mb_idxs(batch_B, minibatch_size, shuffle=True):
            self.optimizer.zero_grad()
            rnn_state = buffer_method(init_rnn_states[idxs], "transpose", 0, 1)
            dist_info, value, _ = self.agent(*agent_inputs[:, idxs], rnn_state)
            loss, opt_info = self.process_returns(
                samples.env.reward[:, idxs],
                done=samples.env.done[:, idxs],
                value_prediction=value.cpu(),
                action=samples.agent.action[:, idxs],
                dist_info=dist_info,
                old_dist_info=samples.agent.agent_info.dist_info[:, idxs],
                opt_info=opt_info,
            )
            loss.backward()
            self.optimizer.step()
            self.clamp_lagrange_multipliers()

            opt_info.loss.append(loss.item())
            self.update_counter += 1
    return opt_info
def step(self, observation, prev_action, prev_reward, device="cpu"):
    """
    Sample an option and an action from the recurrent option-critic policy.

    Inputs are moved to ``self.device``; the model returns mean, log_std,
    termination probabilities (beta), option values (q), the policy over
    options (pi) and the next recurrent state.  Terminations are Bernoulli
    draws from beta, the new option comes from ``self.sample_option``, and
    the action is sampled from the Gaussian selected by that option.  The
    reported value is ``(pi * q).sum(-1)``.  Outputs are moved to ``device``
    and both the recurrent state and the option state are advanced.
    (Presumably executed without gradients by a wrapper not visible here --
    confirm.)
    """
    model_inputs = buffer_to(
        (observation, prev_action, prev_reward), device=self.device)
    mean, log_std, beta, q, pi_omega, next_rnn_state = self.model(
        *model_inputs, self.prev_rnn_state)

    # Sample terminations, then the option to follow.
    terminations = torch.bernoulli(beta).bool()
    dist_info_omega = DistInfo(prob=pi_omega)
    new_option = self.sample_option(terminations, dist_info_omega)

    dist_info = DistInfoStd(mean=mean, log_std=log_std)
    option_mean = select_at_indexes(new_option, mean)
    option_log_std = select_at_indexes(new_option, log_std)
    dist_info_o = DistInfoStd(mean=option_mean, log_std=option_log_std)
    action = self.distribution.sample(dist_info_o)

    # Model handles None, but Buffer does not, make zeros if needed:
    if self.prev_rnn_state is None:
        stored_state = buffer_func(next_rnn_state, torch.zeros_like)
    else:
        stored_state = self.prev_rnn_state
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    stored_state = buffer_method(stored_state, "transpose", 0, 1)

    agent_info = AgentInfoOCRnn(
        dist_info=dist_info,
        dist_info_o=dist_info_o,
        q=q,
        value=(pi_omega * q).sum(-1),
        termination=terminations,
        inter_option_dist_info=dist_info_omega,
        prev_o=self._prev_option,
        o=new_option,
        prev_rnn_state=stored_state,
    )
    action, agent_info = buffer_to((action, agent_info), device=device)
    self.advance_rnn_state(next_rnn_state)  # Keep on device.
    self.advance_oc_state(new_option)
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward, device="cpu"):
    """
    Sample an option and an action from the recurrent categorical
    option-critic policy.

    The previous option (or -1 everywhere when there is none yet) and the
    previous action are one-hot encoded and fed to the model together with
    the observation and previous reward.  The model returns per-option
    action probabilities (pi), termination probabilities (beta), option
    values (q), the policy over options and the next recurrent state.
    ``self.sample_option`` yields the new option and the terminations; the
    action is sampled from the probabilities selected by the new option.
    The reported value is ``(pi_omega * q).sum(-1)``.  Outputs are moved to
    ``device``; option state and recurrent state are advanced.
    """
    option_input = self._prev_option
    if option_input is None:
        # Hack to extract previous option: borrow prev_action's shape/dtype.
        option_input = torch.full_like(prev_action, -1)
    onehot_prev_action = self.distribution.to_onehot(prev_action)
    option_input = self.distribution_omega.to_onehot_with_invalid(option_input)

    model_inputs = buffer_to(
        (observation, onehot_prev_action, prev_reward, option_input),
        device=self.device)
    pi, beta, q, pi_omega, next_rnn_state = self.model(
        *model_inputs, self.prev_rnn_state)

    dist_info_omega = DistInfo(prob=pi_omega)
    # Sample terminations and options.
    new_option, terminations = self.sample_option(beta, dist_info_omega)

    # Model handles None, but Buffer does not, make zeros if needed:
    if self.prev_rnn_state is None:
        stored_state = buffer_func(next_rnn_state, torch.zeros_like)
    else:
        stored_state = self.prev_rnn_state
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    stored_state = buffer_method(stored_state, "transpose", 0, 1)

    dist_info = DistInfo(prob=pi)
    dist_info_o = DistInfo(prob=select_at_indexes(new_option, pi))
    action = self.distribution.sample(dist_info_o)

    agent_info = AgentInfoOCRnn(
        dist_info=dist_info,
        dist_info_o=dist_info_o,
        q=q,
        value=(pi_omega * q).sum(-1),
        termination=terminations,
        dist_info_omega=dist_info_omega,
        prev_o=self._prev_option,
        o=new_option,
        prev_rnn_state=stored_state,
    )
    action, agent_info = buffer_to((action, agent_info), device=device)
    self.advance_oc_state(new_option)
    self.advance_rnn_state(next_rnn_state)
    return AgentStep(action=action, agent_info=agent_info)
def step(self, observation, prev_action, prev_reward):
    """
    Run the recurrent model for one step and sample an action.

    The previous action is one-hot encoded and inputs are moved to
    ``self.device``.  The model returns action probabilities, a value
    estimate, the next recurrent state and a ``conv`` output.  When
    ``self._act_uniform`` is set the probabilities are overwritten in place
    with a uniform distribution before sampling.  ``conv`` is stored in the
    agent info only when ``self.store_latent`` is set.

    Returns:
        ``AgentStep`` with the action and ``AgentInfoRnnConv`` agent info,
        both on CPU.  The agent's recurrent state is advanced.
    """
    prev_action = self.distribution.to_onehot(prev_action)
    model_inputs = buffer_to(
        (observation, prev_action, prev_reward), device=self.device
    )
    pi, value, rnn_state, conv = self.model(*model_inputs, self.prev_rnn_state)
    if self._act_uniform:
        pi[:] = 1.0 / pi.shape[-1]  # uniform
    dist_info = DistInfo(prob=pi)
    action = self.distribution.sample(dist_info)
    # Model handles None, but Buffer does not, make zeros if needed.
    # Test for None explicitly: ``state or zeros`` relies on truthiness,
    # which is ambiguous (raises) for a multi-element tensor state.
    prev_rnn_state = self.prev_rnn_state
    if prev_rnn_state is None:
        prev_rnn_state = buffer_func(rnn_state, torch.zeros_like)
    # Transpose the rnn_state from [N,B,H] --> [B,N,H] for storage.
    # (Special case: model should always leave B dimension in.)
    prev_rnn_state = buffer_method(prev_rnn_state, "transpose", 0, 1)
    agent_info = AgentInfoRnnConv(
        dist_info=dist_info,
        value=value,
        prev_rnn_state=prev_rnn_state,
        conv=conv if self.store_latent else None,
    )  # Don't write the extra data.
    action, agent_info = buffer_to((action, agent_info), device="cpu")
    self.advance_rnn_state(rnn_state)
    return AgentStep(action=action, agent_info=agent_info)
def loss(self, samples):
    """Compute the recurrent Q-learning loss, TD errors and priorities.

    Samples have leading Time and Batch dimensions [T,B,..]. Move all
    samples to device first, and then slice for sub-sequences.  Use same
    init_rnn_state for agent and target; start both at same t.  Warmup the
    RNN state first on the warmup subsequence, then train on the remaining
    subsequence.

    Returns loss (usually use MSE, not Huber), TD-error absolute values
    (masked by ``valid``), and new sequence-wise priorities, based on
    weighted sum (``pri_eta``) of max and mean TD-error over the sequence."""
    all_observation, all_action, all_reward = buffer_to(
        (samples.all_observation, samples.all_action, samples.all_reward),
        device=self.agent.device)
    wT, bT, nsr = self.warmup_T, self.batch_T, self.n_step_return
    if wT > 0:
        warmup_slice = slice(None, wT)  # Same for agent and target.
        warmup_inputs = AgentInputs(
            observation=all_observation[warmup_slice],
            prev_action=all_action[warmup_slice],
            prev_reward=all_reward[warmup_slice],
        )
    # Training segment for the online network: bT steps right after warmup.
    agent_slice = slice(wT, wT + bT)
    agent_inputs = AgentInputs(
        observation=all_observation[agent_slice],
        prev_action=all_action[agent_slice],
        prev_reward=all_reward[agent_slice],
    )
    # The target runs to the end of the sequence; only its last bT outputs
    # are kept further down.
    target_slice = slice(wT, None)  # Same start t as agent. (wT + bT + nsr)
    target_inputs = AgentInputs(
        observation=all_observation[target_slice],
        prev_action=all_action[target_slice],
        prev_reward=all_reward[target_slice],
    )
    # Actions are read one step later than the inputs (offset by +1).
    action = samples.all_action[wT + 1:wT + 1 + bT]  # CPU.
    return_ = samples.return_[wT:wT + bT]
    done_n = samples.done_n[wT:wT + bT]
    if self.store_rnn_state_interval == 0:
        init_rnn_state = None
    else:
        # [B,N,H]-->[N,B,H] cudnn.
        init_rnn_state = buffer_method(samples.init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
    if wT > 0:  # Do warmup.
        # Warmup only produces recurrent states; no gradients are needed.
        with torch.no_grad():
            _, target_rnn_state = self.agent.target(*warmup_inputs, init_rnn_state)
            _, init_rnn_state = self.agent(*warmup_inputs, init_rnn_state)
        # Recommend aligning sampling batch_T and store_rnn_interval with
        # warmup_T (and no mid_batch_reset), so that end of trajectory
        # during warmup leads to new trajectory beginning at start of
        # training segment of replay.
        # Zero the state (in place) wherever the warmup segment was already
        # invalid at its last step.
        warmup_invalid_mask = valid_from_done(samples.done[:wT])[-1] == 0  # [B]
        init_rnn_state[:, warmup_invalid_mask] = 0  # [N,B,H] (cudnn)
        target_rnn_state[:, warmup_invalid_mask] = 0
    else:
        target_rnn_state = init_rnn_state

    qs, _ = self.agent(*agent_inputs, init_rnn_state)  # [T,B,A]
    q = select_at_indexes(action, qs)
    with torch.no_grad():
        target_qs, _ = self.agent.target(*target_inputs, target_rnn_state)
        if self.double_dqn:
            # Double DQN: the online network chooses the action, the target
            # network evaluates it.
            next_qs, _ = self.agent(*target_inputs, init_rnn_state)
            next_a = torch.argmax(next_qs, dim=-1)
            target_q = select_at_indexes(next_a, target_qs)
        else:
            target_q = torch.max(target_qs, dim=-1).values
        target_q = target_q[-bT:]  # Same length as q.
    disc = self.discount ** self.n_step_return
    # Bootstrapped target, computed in unscaled space and re-scaled.
    y = self.value_scale(return_ + (1 - done_n.float()) * disc *
        self.inv_value_scale(target_q))  # [T,B]
    delta = y - q
    losses = 0.5 * delta ** 2
    abs_delta = abs(delta)
    # NOTE: by default, with R2D1, use squared-error loss, delta_clip=None.
    if self.delta_clip is not None:  # Huber loss.
        b = self.delta_clip * (abs_delta - self.delta_clip / 2)
        losses = torch.where(abs_delta <= self.delta_clip, losses, b)
    if self.prioritized_replay:
        losses *= samples.is_weights.unsqueeze(0)  # weights: [B] --> [1,B]
    valid = valid_from_done(samples.done[wT:])  # 0 after first done.
    loss = valid_mean(losses, valid)
    td_abs_errors = abs_delta.detach()
    if self.delta_clip is not None:
        td_abs_errors = torch.clamp(td_abs_errors, 0, self.delta_clip)  # [T,B]
    valid_td_abs_errors = td_abs_errors * valid
    max_d = torch.max(valid_td_abs_errors, dim=0).values
    mean_d = valid_mean(td_abs_errors, valid, dim=0)  # Still high if less valid.
    priorities = self.pri_eta * max_d + (1 - self.pri_eta) * mean_d  # [B]
    return loss, valid_td_abs_errors, priorities
def loss(self,
         agent_inputs,
         action,
         return_,
         advantage,
         valid,
         old_dist_info,
         bc_observations,
         bc_actions,
         init_rnn_state=None):
    """
    Compute the BC-augmented PPO training loss:
    ``policy_loss + value_loss + entropy_loss + bc_loss``.

    Policy loss: min(likelihood_ratio * advantage,
                     clip(likelihood_ratio, 1-eps, 1+eps) * advantage)
    Value loss:  0.5 * (estimated_value - return) ^ 2
    BC loss:     -bc_loss_coeff * mean log-likelihood of ``bc_actions``
                 under the policy evaluated on ``bc_observations``

    Calls the agent to compute the forward pass on the training data, and
    uses ``agent.distribution`` for likelihoods and entropies. The PPO part
    works for feedforward or recurrent agents; the BC part is feedforward
    only.

    Args:
        agent_inputs: inputs unpacked positionally into ``self.agent(...)``.
        action: actions handed to ``distribution.likelihood_ratio``.
        return_: regression targets for the value estimate.
        advantage: advantages weighting the likelihood ratio.
        valid: mask forwarded to ``valid_mean`` and to the entropy and
            perplexity helpers of the distribution.
        old_dist_info: distribution info of the data-collecting policy.
        bc_observations: demonstration observations fed to the agent.
        bc_actions: demonstration action labels; ``bc_actions.shape[0]`` is
            used as the length of the dummy previous-reward vector.
        init_rnn_state: optional initial recurrent state in [B,N,H] layout.

    Returns:
        Tuple ``(loss, entropy, perplexity)``.

    Raises:
        NotImplementedError: if ``self.bc_loss_coeff`` is truthy and
            ``init_rnn_state`` is given (BC is not implemented for RNNs).
    """
    # Fail fast: reject the unsupported BC + recurrent combination before
    # paying for the forward pass, instead of after all PPO terms are built.
    if self.bc_loss_coeff and init_rnn_state is not None:
        raise NotImplementedError("doesn't quite work with RNNs yet")

    if init_rnn_state is not None:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, value, _rnn_state = self.agent(*agent_inputs,
                                                  init_rnn_state)
    else:
        dist_info, value = self.agent(*agent_inputs)
    dist = self.agent.distribution

    # Clipped-surrogate policy loss.
    ratio = dist.likelihood_ratio(action,
                                  old_dist_info=old_dist_info,
                                  new_dist_info=dist_info)
    surr_1 = ratio * advantage
    clipped_ratio = torch.clamp(ratio, 1. - self.ratio_clip,
                                1. + self.ratio_clip)
    surr_2 = clipped_ratio * advantage
    surrogate = torch.min(surr_1, surr_2)
    pi_loss = -valid_mean(surrogate, valid)

    # TODO: log the value error and correlation
    value_error = 0.5 * (value - return_)**2
    value_loss = self.value_loss_coeff * valid_mean(value_error, valid)

    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    # BC loss (this is the only part that differs from plain PPO).
    if self.bc_loss_coeff:
        # The agent signature requires prev_action / prev_reward, so feed
        # placeholders. This will break for an agent/model that actually
        # uses the previous action and reward (IIRC that only includes
        # recurrent agents in rlpyt, though).
        dummy_prev_action = bc_actions
        dummy_prev_reward = torch.zeros(bc_actions.shape[0],
                                        device=bc_actions.device)
        bc_dist_info, _ = self.agent(bc_observations, dummy_prev_action,
                                     dummy_prev_reward)
        expert_ll = dist.log_likelihood(bc_actions, bc_dist_info)
        # TODO: also log BC accuracy (or maybe do it somewhere else, IDK)
        # Plain mean: demonstration samples carry no validity mask here.
        bc_loss = -self.bc_loss_coeff * expert_ll.mean()
    else:
        bc_loss = 0.0

    loss = pi_loss + value_loss + entropy_loss + bc_loss

    perplexity = dist.mean_perplexity(dist_info, valid)
    return loss, entropy, perplexity
def combined_loss(self,
                  agent_inputs,
                  action,
                  next_obs,
                  ext_return,
                  ext_adv,
                  int_return,
                  int_adv,
                  valid,
                  old_dist_info,
                  init_rnn_state=None):
    """
    PPO loss variant that also trains the intrinsic bonus model.

    Runs the actor-critic forward pass (which yields separate extrinsic and
    intrinsic value estimates), runs ``agent.bonus_call`` to obtain the
    bonus model's own loss, mixes the extrinsic and intrinsic advantages
    into a single advantage, and assembles the clipped-surrogate PPO loss.

    Returns:
        Tuple ``(loss, entropy, perplexity, pi_loss, value_loss,
        entropy_loss, bonus_loss)`` where ``bonus_loss`` is already scaled
        by ``self.bonus_loss_coeff``.
    """
    # --- Actor-critic forward pass -------------------------------------
    if init_rnn_state is None:
        dist_info, ext_value, int_value = self.agent(*agent_inputs)
    else:
        # [B,N,H] --> [N,B,H] (for cudnn).
        init_rnn_state = buffer_method(init_rnn_state, "transpose", 0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")
        dist_info, ext_value, int_value, _rnn_state = self.agent(
            *agent_inputs, init_rnn_state)
    dist = self.agent.distribution

    # --- Bonus model: second call, yields its self-supervised loss ------
    # Leading batch dims have already been flattened after entering the
    # minibatch. ``next_obs`` may be the same as ``observation`` (dummy
    # placeholder) if the algo was configured with next_obs=False.
    bonus_model_inputs = self.agent.extract_bonus_inputs(
        observation=agent_inputs.observation,
        next_observation=next_obs,
        action=action)
    _, bonus_loss = self.agent.bonus_call(bonus_model_inputs)
    bonus_loss *= self.bonus_loss_coeff  # in-place scaling, as before

    # --- Fuse the two reward streams into one advantage ------------------
    advantage = self.ext_rew_coeff * ext_adv + self.int_rew_coeff * int_adv

    # --- Clipped-surrogate policy term -----------------------------------
    ratio = dist.likelihood_ratio(action,
                                  old_dist_info=old_dist_info,
                                  new_dist_info=dist_info)
    lower, upper = 1. - self.ratio_clip, 1. + self.ratio_clip
    unclipped = ratio * advantage
    clipped = torch.clamp(ratio, lower, upper) * advantage
    pi_loss = -valid_mean(torch.min(unclipped, clipped), valid)

    # --- Value term: extrinsic head is masked, intrinsic head is not -----
    ext_value_error = 0.5 * (ext_value - ext_return)**2
    int_value_error = 0.5 * (int_value - int_return)**2
    summed_value_error = (valid_mean(ext_value_error, valid) +
                          int_value_error.mean())
    value_loss = self.value_loss_coeff * summed_value_error

    # --- Entropy bonus ----------------------------------------------------
    entropy = dist.mean_entropy(dist_info, valid)
    entropy_loss = -self.entropy_loss_coeff * entropy

    loss = pi_loss + value_loss + entropy_loss + bonus_loss

    perplexity = dist.mean_perplexity(dist_info, valid)
    return (loss, entropy, perplexity, pi_loss, value_loss, entropy_loss,
            bonus_loss)
def loss(self, samples: SamplesFromReplay, sample_itr: int, opt_itr: int):
    """
    Compute the Dreamer losses for a batch of replayed sequences.

    Three losses are built in sequence:

    * model loss: image and reward log-likelihoods, a KL term between
      posterior and prior clamped from below at ``self.free_nats``, and
      (when ``self.use_pcont``) a continuation-prediction term;
    * actor loss: negative discounted return of imagined rollouts started
      from the posterior states;
    * value loss: negative log-likelihood of the detached imagined returns
      under the value model.

    :param samples: samples from replay
    :param sample_itr: sample iteration (used to decide when to log video)
    :param opt_itr: optimization iteration (used to decide when to log video)
    :return: tuple ``(model_loss, actor_loss, value_loss, loss_info)`` where
        ``loss_info`` is a ``LossInfo`` holding the individual terms
    """
    world = self.agent.model

    # Align all sequences on the same time steps.
    observation = samples.all_observation[:-1]  # [t, t+batch_length+1] -> [t, t+batch_length]
    action = samples.all_action[1:]  # [t-1, t+batch_length] -> [t, t+batch_length]
    reward = samples.all_reward[1:].unsqueeze(2)  # [t-1, t+batch_length] -> [t, t+batch_length]
    done = samples.done.unsqueeze(2)

    # Only the time and batch sizes are needed from the leading dims.
    _, batch_t, batch_b, _ = infer_leading_dims(observation, 3)

    # Normalize the image, then embed it.
    observation = observation.type(self.type) / 255.0 - 0.5
    embed = world.observation_encoder(observation)

    # Roll the representation model along the actions that were really taken.
    prev_state = world.representation.initial_state(batch_b,
                                                    device=action.device,
                                                    dtype=action.dtype)
    prior, post = world.rollout.rollout_representation(
        batch_t, embed, action, prev_state)

    # ------------------------------ Model loss ------------------------------
    feat = get_feat(post)
    image_pred = world.observation_decoder(feat)
    reward_pred = world.reward_model(feat)
    reward_loss = -torch.mean(reward_pred.log_prob(reward))
    image_loss = -torch.mean(image_pred.log_prob(observation))
    if self.use_pcont:
        pcont_pred = world.pcont(feat)
        pcont_target = self.discount * (1 - done.float())
        pcont_loss = -torch.mean(pcont_pred.log_prob(pcont_target))
    else:
        pcont_loss = torch.tensor(0.)  # placeholder when pcont is unused
    prior_dist = get_dist(prior)
    post_dist = get_dist(post)
    div = torch.mean(
        torch.distributions.kl.kl_divergence(post_dist, prior_dist))
    # Clamp the KL from below at free_nats.
    div = torch.max(div, div.new_full(div.size(), self.free_nats))
    model_loss = self.kl_scale * div + reward_loss + image_loss
    if self.use_pcont:
        model_loss += self.pcont_scale * pcont_loss

    # --------------------------- Gradient barrier ---------------------------
    # Don't let gradients pass through to prevent overwriting gradients.
    # ------------------------------ Actor loss ------------------------------
    # Flatten the posterior states (first dim becomes batch_t * batch_b) so a
    # new imagined rollout starts from every visited state; done without
    # gradients so the previously computed tensors are cut off.
    with torch.no_grad():
        if self.use_pcont:
            # "Last step could be terminal." Done in TF2 code, but unclear why
            start_states = post[:-1, :]
            n_starts = (batch_t - 1) * (batch_b)
        else:
            start_states = post
            n_starts = batch_t * batch_b
        flat_post = buffer_method(start_states, 'reshape', n_starts, -1)

    # Roll out the policy for self.horizon steps. Names prefixed with imag_
    # hold imagined, not real, data.
    with FreezeParameters(self.model_modules):
        imag_dist, _ = world.rollout.rollout_policy(self.horizon,
                                                    world.policy, flat_post)

    # State features (deterministic and stochastic) feed reward and value.
    imag_feat = get_feat(imag_dist)  # [horizon, batch_t * batch_b, feature_size]

    # Uses the distribution means; the TF code uses the mode, which is the
    # same for a normal distribution. Other distributions would need a fix.
    # Freeze model parameters as only action model gradients are needed.
    with FreezeParameters(self.model_modules + self.value_modules):
        imag_reward = world.reward_model(imag_feat).mean
        value = world.value_model(imag_feat).mean

    # Per-step discount: predicted continuation, or the constant discount.
    if self.use_pcont:
        with FreezeParameters([world.pcont]):
            discount_arr = world.pcont(imag_feat).mean
    else:
        discount_arr = self.discount * torch.ones_like(imag_reward)

    returns = self.compute_return(imag_reward[:-1],
                                  value[:-1],
                                  discount_arr[:-1],
                                  bootstrap=value[-1],
                                  lambda_=self.discount_lambda)
    # Make the top row 1 so the cumulative product starts with discount^0.
    discount_arr = torch.cat(
        [torch.ones_like(discount_arr[:1]), discount_arr[1:]])
    discount = torch.cumprod(discount_arr[:-1], 0)
    actor_loss = -torch.mean(discount * returns)

    # --------------------------- Gradient barrier ---------------------------
    # Don't let gradients pass through to prevent overwriting gradients.
    # ------------------------------ Value loss ------------------------------
    with torch.no_grad():
        value_feat = imag_feat[:-1].detach()
        value_discount = discount.detach()
        value_target = returns.detach()
    value_pred = world.value_model(value_feat)
    log_prob = value_pred.log_prob(value_target)
    value_loss = -torch.mean(value_discount * log_prob.unsqueeze(2))

    # --------------------------- Gradient barrier ---------------------------
    # Loss info and optional video logging, all without gradients.
    with torch.no_grad():
        prior_ent = torch.mean(prior_dist.entropy())
        post_ent = torch.mean(post_dist.entropy())
        loss_info = LossInfo(model_loss, actor_loss, value_loss, prior_ent,
                             post_ent, div, reward_loss, image_loss,
                             pcont_loss)

        # Videos are written only on the last optimization step of a sample
        # iteration that is a multiple of ``video_every``.
        if (self.log_video and opt_itr == self.train_steps - 1
                and sample_itr % self.video_every == 0):
            self.write_videos(observation,
                              action,
                              image_pred,
                              post,
                              step=sample_itr,
                              n=self.video_summary_b,
                              t=self.video_summary_t)

    return model_loss, actor_loss, value_loss, loss_info
def loss(self, samples):
    """
    Recurrent SAC losses on a replayed sequence batch.

    Computes losses for the twin Q-values against the min of the twin target
    Q-values plus an entropy term, the reparameterized policy loss, and the
    loss for tuning the entropy weight alpha. The sequence is sliced along
    its leading axis into an optional warmup segment of ``warmup_T`` steps
    (used only to advance the recurrent states, without gradients), a
    training segment of ``batch_T`` steps, and the remainder that is only
    seen by the target networks.

    Returns:
        ``(losses, values)`` with
        ``losses = (q1_loss, q2_loss, pi_loss, alpha_loss)`` (``alpha_loss``
        is ``None`` when alpha is fixed or no target entropy is set) and
        ``values`` the detached ``(q1, q2, pi_mean, pi_log_std)``.
    """
    # SamplesFromReplay = namedarraytuple("SamplesFromReplay",
    #   ["all_observation", "all_action", "all_reward", "return_", "done",
    #    "done_n", "init_rnn_state"])
    all_observation, all_action, all_reward = buffer_to(
        (samples.all_observation, samples.all_action, samples.all_reward),
        device=self.agent.device)  # all have (wT + bT + nsr) x bB
    wT, bT = self.warmup_T, self.batch_T

    def _inputs_for(window):
        # Agent inputs restricted to one window of the sequence.
        return AgentInputs(
            observation=all_observation[window],
            prev_action=all_action[window],
            prev_reward=all_reward[window],
        )

    if wT > 0:
        warmup_inputs = _inputs_for(slice(None, wT))  # Same for agent and target.
    agent_inputs = _inputs_for(slice(wT, wT + bT))
    # Same start t as agent; runs to the end (wT + bT + nsr).
    target_inputs = _inputs_for(slice(wT, None))

    # The 'current' action is prev_action shifted by one index.
    warmup_action = samples.all_action[1:wT + 1]
    action = samples.all_action[wT + 1:wT + 1 + bT]
    return_ = samples.return_[wT:wT + bT]
    done_n = samples.done_n[wT:wT + bT]

    if self.store_rnn_state_interval == 0:
        init_rnn_state = None
    else:
        # [B,N,H]-->[N,B,H] cudnn.
        init_rnn_state = buffer_method(samples.init_rnn_state, "transpose",
                                       0, 1)
        init_rnn_state = buffer_method(init_rnn_state, "contiguous")

    if wT > 0:  # Do warmup.
        with torch.no_grad():
            _, target_q1_rnn_state, _, target_q2_rnn_state = \
                self.agent.target_q(*warmup_inputs, warmup_action,
                                    init_rnn_state, init_rnn_state)
            _, _, _, init_rnn_state = self.agent.pi(*warmup_inputs,
                                                    init_rnn_state)
        # Recommend aligning sampling batch_T and store_rnn_interval with
        # warmup_T (and no mid_batch_reset), so that end of trajectory
        # during warmup leads to new trajectory beginning at start of
        # training segment of replay.
        warmup_invalid_mask = valid_from_done(
            samples.done[:wT])[-1] == 0  # [B]
        # Zero the recurrent state of every sequence that ended in warmup.
        init_rnn_state[:, warmup_invalid_mask] = 0  # [N,B,H] (cudnn)
        target_q1_rnn_state[:, warmup_invalid_mask] = 0
        target_q2_rnn_state[:, warmup_invalid_mask] = 0
    else:
        target_q1_rnn_state = init_rnn_state
        target_q2_rnn_state = init_rnn_state

    # NOTE(review): validity is computed over the whole sequence, warmup
    # included, before keeping the last bT steps -- presumably a done during
    # warmup then invalidates the entire training segment; confirm this is
    # intended.
    valid = valid_from_done(samples.done)[-bT:]

    # --- Twin Q losses ------------------------------------------------------
    q1, _, q2, _ = self.agent.q(*agent_inputs, action, init_rnn_state,
                                init_rnn_state)
    with torch.no_grad():
        target_action, target_log_pi, _, _ = self.agent.pi(
            *target_inputs, init_rnn_state)
        target_q1, _, target_q2, _ = self.agent.target_q(
            *target_inputs, target_action, target_q1_rnn_state,
            target_q2_rnn_state)
        # Keep the last bT steps so the targets have the same length as q.
        target_q1 = target_q1[-bT:]
        target_q2 = target_q2[-bT:]
        target_log_pi = target_log_pi[-bT:]
    min_target_q = torch.min(target_q1, target_q2)
    target_value = min_target_q - self._alpha * target_log_pi
    disc = self.discount**self.n_step_return
    y = (self.reward_scale * return_ +
         (1 - done_n.float()) * disc * target_value)

    q1_loss = 0.5 * valid_mean((y - q1)**2, valid)
    q2_loss = 0.5 * valid_mean((y - q2)**2, valid)

    # --- Reparameterized policy loss ----------------------------------------
    new_action, log_pi, (pi_mean, pi_log_std), _ = self.agent.pi(
        *agent_inputs, init_rnn_state)
    log_target1, _, log_target2, _ = self.agent.q(*agent_inputs, new_action,
                                                  init_rnn_state,
                                                  init_rnn_state)
    min_log_target = torch.min(log_target1, log_target2)
    prior_log_pi = self.get_action_prior(new_action.cpu())

    pi_losses = self._alpha * log_pi - min_log_target - prior_log_pi
    pi_loss = valid_mean(pi_losses, valid)

    # --- Entropy-weight (alpha) loss, only when alpha is being tuned --------
    alpha_loss = None
    if self.target_entropy is not None and self.fixed_alpha is None:
        alpha_losses = -self._log_alpha * (log_pi.detach() +
                                           self.target_entropy)
        alpha_loss = valid_mean(alpha_losses, valid)

    losses = (q1_loss, q2_loss, pi_loss, alpha_loss)
    values = (q1.detach(), q2.detach(), pi_mean.detach(),
              pi_log_std.detach())
    return losses, values