class TorchSACOptimizer(TorchOptimizer):
    """
    Soft Actor-Critic optimizer for a TorchPolicy with continuous, discrete or
    hybrid actions. It owns twin Q networks, a target value network, learned
    log entropy coefficients and three Adam optimizers (policy, value, entropy).
    """

    class PolicyValueNetwork(nn.Module):
        """Container holding the twin Q networks (Q1 and Q2)."""

        def __init__(
            self,
            stream_names: List[str],
            sensor_specs: List[SensorSpec],
            network_settings: NetworkSettings,
            action_spec: ActionSpec,
        ):
            """
            Builds two identically-configured ValueNetworks.
            :param stream_names: Names of the reward streams; one Q output per stream.
            :param sensor_specs: Specs of the observations fed to the networks.
            :param network_settings: Settings forwarded to each ValueNetwork.
            :param action_spec: Action spec; the continuous size is used as the
                number of action inputs and the summed discrete branch sizes
                (at least 1) as the number of value outputs.
            """
            super().__init__()
            num_value_outs = max(sum(action_spec.discrete_branches), 1)
            num_action_ins = int(action_spec.continuous_size)
            self.q1_network = ValueNetwork(
                stream_names,
                sensor_specs,
                network_settings,
                num_action_ins,
                num_value_outs,
            )
            self.q2_network = ValueNetwork(
                stream_names,
                sensor_specs,
                network_settings,
                num_action_ins,
                num_value_outs,
            )

        def forward(
            self,
            inputs: List[torch.Tensor],
            actions: Optional[torch.Tensor] = None,
            memories: Optional[torch.Tensor] = None,
            sequence_length: int = 1,
            q1_grad: bool = True,
            q2_grad: bool = True,
        ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
            """
            Performs a forward pass on the value network, which consists of a Q1 and Q2
            network. Optionally does not evaluate gradients for either the Q1, Q2, or both.
            :param inputs: List of observation tensors.
            :param actions: For a continuous Q function (has actions), tensor of actions.
                Otherwise, None.
            :param memories: Initial memories if using memory. Otherwise, None.
            :param sequence_length: Sequence length if using memory.
            :param q1_grad: Whether or not to compute gradients for the Q1 network.
            :param q2_grad: Whether or not to compute gradients for the Q2 network.
            :return: Tuple of two dictionaries, which both map {reward_signal: Q} for Q1 and Q2,
                respectively.
            """
            # ExitStack allows us to enter the torch.no_grad() context conditionally
            with ExitStack() as stack:
                if not q1_grad:
                    stack.enter_context(torch.no_grad())  # pylint: disable=E1101
                q1_out, _ = self.q1_network(
                    inputs,
                    actions=actions,
                    memories=memories,
                    sequence_length=sequence_length,
                )
            with ExitStack() as stack:
                if not q2_grad:
                    stack.enter_context(torch.no_grad())  # pylint: disable=E1101
                q2_out, _ = self.q2_network(
                    inputs,
                    actions=actions,
                    memories=memories,
                    sequence_length=sequence_length,
                )
            return q1_out, q2_out

    class TargetEntropy(NamedTuple):
        """Target entropies used by the entropy-coefficient loss."""

        discrete: List[float] = []  # One per branch
        continuous: float = 0.0

    class LogEntCoef(nn.Module):
        """Module wrapper so both log entropy coefficients are registered parameters."""

        def __init__(self, discrete, continuous):
            super().__init__()
            self.discrete = discrete
            self.continuous = continuous

    def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings):
        """
        Creates the Q networks, target network, entropy coefficients and optimizers.
        :param policy: The TorchPolicy that this optimizer updates.
        :param trainer_params: Trainer settings; `hyperparameters` is read as SACSettings.
        """
        super().__init__(policy, trainer_params)
        hyperparameters: SACSettings = cast(
            SACSettings, trainer_params.hyperparameters
        )
        self.tau = hyperparameters.tau
        self.init_entcoef = hyperparameters.init_entcoef

        self.policy = policy
        policy_network_settings = policy.network_settings

        self.burn_in_ratio = 0.0

        # Non-exposed SAC parameters
        self.discrete_target_entropy_scale = 0.2  # Roughly equal to e-greedy 0.05
        self.continuous_target_entropy_scale = 1.0

        self.stream_names = list(self.reward_signals.keys())
        # Use to reduce "survivor bonus" when using Curiosity or GAIL.
        self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()]
        self.use_dones_in_backup = {
            name: int(not self.reward_signals[name].ignore_done)
            for name in self.stream_names
        }
        self._action_spec = self.policy.behavior_spec.action_spec

        self.value_network = TorchSACOptimizer.PolicyValueNetwork(
            self.stream_names,
            self.policy.behavior_spec.sensor_specs,
            policy_network_settings,
            self._action_spec,
        )

        self.target_network = ValueNetwork(
            self.stream_names,
            self.policy.behavior_spec.sensor_specs,
            policy_network_settings,
        )
        # tau=1.0 is used for the initial sync of the target with the policy's critic.
        ModelUtils.soft_update(
            self.policy.actor_critic.critic, self.target_network, 1.0
        )

        # We create one entropy coefficient per action, whether discrete or continuous.
        _disc_log_ent_coef = torch.nn.Parameter(
            torch.log(
                torch.as_tensor(
                    [self.init_entcoef] * len(self._action_spec.discrete_branches)
                )
            ),
            requires_grad=True,
        )
        _cont_log_ent_coef = torch.nn.Parameter(
            torch.log(torch.as_tensor([self.init_entcoef])), requires_grad=True
        )
        self._log_ent_coef = TorchSACOptimizer.LogEntCoef(
            discrete=_disc_log_ent_coef, continuous=_cont_log_ent_coef
        )
        _cont_target = (
            -1
            * self.continuous_target_entropy_scale
            * np.prod(self._action_spec.continuous_size).astype(np.float32)
        )
        _disc_target = [
            self.discrete_target_entropy_scale * np.log(i).astype(np.float32)
            for i in self._action_spec.discrete_branches
        ]
        self.target_entropy = TorchSACOptimizer.TargetEntropy(
            continuous=_cont_target, discrete=_disc_target
        )

        policy_params = list(self.policy.actor_critic.network_body.parameters()) + list(
            self.policy.actor_critic.action_model.parameters()
        )
        value_params = list(self.value_network.parameters()) + list(
            self.policy.actor_critic.critic.parameters()
        )

        logger.debug("value_vars")
        for param in value_params:
            logger.debug(param.shape)
        logger.debug("policy_vars")
        for param in policy_params:
            logger.debug(param.shape)

        self.decay_learning_rate = ModelUtils.DecayedValue(
            hyperparameters.learning_rate_schedule,
            hyperparameters.learning_rate,
            1e-10,
            self.trainer_settings.max_steps,
        )
        self.policy_optimizer = torch.optim.Adam(
            policy_params, lr=hyperparameters.learning_rate
        )
        self.value_optimizer = torch.optim.Adam(
            value_params, lr=hyperparameters.learning_rate
        )
        self.entropy_optimizer = torch.optim.Adam(
            self._log_ent_coef.parameters(), lr=hyperparameters.learning_rate
        )
        self._move_to_device(default_device())

    def _move_to_device(self, device: torch.device) -> None:
        """Moves the entropy coefficients, target network and Q networks to `device`."""
        self._log_ent_coef.to(device)
        self.target_network.to(device)
        self.value_network.to(device)

    def sac_q_loss(
        self,
        q1_out: Dict[str, torch.Tensor],
        q2_out: Dict[str, torch.Tensor],
        target_values: Dict[str, torch.Tensor],
        dones: torch.Tensor,
        rewards: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Computes the Q1 and Q2 losses against the one-step backup
        reward + (1 - done) * gamma * target_value, averaged over reward streams.
        :param q1_out: Q1 estimates per reward stream.
        :param q2_out: Q2 estimates per reward stream.
        :param target_values: Target network value estimates per reward stream.
        :param dones: Done flags for the batch.
        :param rewards: Rewards per reward stream.
        :param loss_masks: Mask for losses; masked-out steps do not contribute.
        :return: Tuple of (q1_loss, q2_loss).
        """
        q1_losses = []
        q2_losses = []
        # Multiple q losses per stream
        # NOTE(review): gammas are indexed by position, which assumes q1_out iterates
        # in the same order as trainer_params.reward_signals - confirm.
        for i, name in enumerate(q1_out.keys()):
            q1_stream = q1_out[name].squeeze()
            q2_stream = q2_out[name].squeeze()
            with torch.no_grad():
                q_backup = rewards[name] + (
                    (1.0 - self.use_dones_in_backup[name] * dones)
                    * self.gammas[i]
                    * target_values[name]
                )
            # reduction="none" keeps the per-sample errors so that masked_mean can
            # actually apply loss_masks; the default reduction collapses to a scalar
            # before masking, which makes the mask a no-op.
            _q1_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q1_stream, reduction="none"),
                loss_masks,
            )
            _q2_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(q_backup, q2_stream, reduction="none"),
                loss_masks,
            )

            q1_losses.append(_q1_loss)
            q2_losses.append(_q2_loss)
        q1_loss = torch.mean(torch.stack(q1_losses))
        q2_loss = torch.mean(torch.stack(q2_losses))
        return q1_loss, q2_loss

    def sac_value_loss(
        self,
        log_probs: ActionLogProbs,
        values: Dict[str, torch.Tensor],
        q1p_out: Dict[str, torch.Tensor],
        q2p_out: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the value loss against the entropy-regularized minimum of the
        two Q estimates for the currently sampled policy actions.
        :param log_probs: Log probabilities of the sampled actions.
        :param values: Value estimates per reward stream.
        :param q1p_out: Q1 estimates for the policy's actions, per reward stream.
        :param q2p_out: Q2 estimates for the policy's actions, per reward stream.
        :param loss_masks: Mask for losses; masked-out steps do not contribute.
        :return: Value loss averaged over reward streams.
        :raises UnityTrainerException: If the resulting loss is Inf or NaN.
        """
        min_policy_qs = {}
        with torch.no_grad():
            _cont_ent_coef = self._log_ent_coef.continuous.exp()
            _disc_ent_coef = self._log_ent_coef.discrete.exp()
            for name in values.keys():
                if self._action_spec.discrete_size <= 0:
                    min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
                else:
                    # Expected Q under the discrete policy: weight by action
                    # probabilities, sum within each branch, average over branches.
                    disc_action_probs = log_probs.all_discrete_tensor.exp()
                    _branched_q1p = ModelUtils.break_into_branches(
                        q1p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _branched_q2p = ModelUtils.break_into_branches(
                        q2p_out[name] * disc_action_probs,
                        self._action_spec.discrete_branches,
                    )
                    _q1p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q1p
                            ]
                        ),
                        dim=0,
                    )
                    _q2p_mean = torch.mean(
                        torch.stack(
                            [
                                torch.sum(_br, dim=1, keepdim=True)
                                for _br in _branched_q2p
                            ]
                        ),
                        dim=0,
                    )

                    min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

        value_losses = []
        if self._action_spec.discrete_size <= 0:
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.sum(
                        _cont_ent_coef * log_probs.continuous_tensor, dim=1
                    )
                # Per-sample errors (reduction="none") so loss_masks takes effect.
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(
                        values[name], v_backup, reduction="none"
                    ),
                    loss_masks,
                )
                value_losses.append(value_loss)
        else:
            disc_log_probs = log_probs.all_discrete_tensor
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            for name in values.keys():
                with torch.no_grad():
                    v_backup = min_policy_qs[name] - torch.mean(
                        branched_ent_bonus, dim=0
                    )
                    # Add continuous entropy bonus to minimum Q
                    # NOTE(review): the continuous-only branch above subtracts
                    # coef * log_prob, whereas this hybrid branch adds it - confirm
                    # the intended sign. Behavior is left unchanged here.
                    if self._action_spec.continuous_size > 0:
                        v_backup += torch.sum(
                            _cont_ent_coef * log_probs.continuous_tensor,
                            dim=1,
                            keepdim=True,
                        )
                # Per-sample errors (reduction="none") so loss_masks takes effect.
                value_loss = 0.5 * ModelUtils.masked_mean(
                    torch.nn.functional.mse_loss(
                        values[name], v_backup.squeeze(), reduction="none"
                    ),
                    loss_masks,
                )
                value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        # The check covers both Inf and NaN, although the message only mentions Inf.
        if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
            raise UnityTrainerException("Inf found")
        return value_loss

    def sac_policy_loss(
        self,
        log_probs: ActionLogProbs,
        q1p_outs: Dict[str, torch.Tensor],
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Computes the policy loss, i.e. entropy-coefficient-weighted log
        probabilities minus the Q1 estimate (averaged over reward streams).
        :param log_probs: Log probabilities of the sampled actions.
        :param q1p_outs: Q1 estimates for the policy's actions, per reward stream.
        :param loss_masks: Mask for losses; masked-out steps do not contribute.
        :return: Masked mean policy loss.
        """
        _cont_ent_coef = self._log_ent_coef.continuous.exp()
        _disc_ent_coef = self._log_ent_coef.discrete.exp()

        mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), dim=0)

        batch_policy_loss = 0
        if self._action_spec.discrete_size > 0:
            disc_log_probs = log_probs.all_discrete_tensor
            disc_action_probs = disc_log_probs.exp()
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_action_probs,
                self._action_spec.discrete_branches,
            )
            branched_q_term = ModelUtils.break_into_branches(
                mean_q1 * disc_action_probs, self._action_spec.discrete_branches
            )
            branched_policy_loss = torch.stack(
                [
                    torch.sum(_disc_ent_coef[i] * _lp - _qt, dim=1, keepdim=False)
                    for i, (_lp, _qt) in enumerate(
                        zip(branched_per_action_ent, branched_q_term)
                    )
                ],
                dim=1,
            )
            batch_policy_loss += torch.sum(branched_policy_loss, dim=1)
            all_mean_q1 = torch.sum(disc_action_probs * mean_q1, dim=1)
        else:
            all_mean_q1 = mean_q1
        if self._action_spec.continuous_size > 0:
            cont_log_probs = log_probs.continuous_tensor
            batch_policy_loss += torch.mean(
                _cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1
            )
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)

        return policy_loss

    def sac_entropy_loss(
        self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
    ) -> torch.Tensor:
        """
        Computes the loss used to adapt the log entropy coefficients towards the
        target entropies.
        :param log_probs: Log probabilities of the sampled actions.
        :param loss_masks: Mask for losses; masked-out steps do not contribute.
        :return: Sum of the discrete and continuous entropy-coefficient losses.
        """
        _cont_ent_coef, _disc_ent_coef = (
            self._log_ent_coef.continuous,
            self._log_ent_coef.discrete,
        )
        entropy_loss = 0
        if self._action_spec.discrete_size > 0:
            # Only the coefficients should receive gradients, so the entropy
            # difference itself is computed without grad.
            with torch.no_grad():
                # Break continuous into separate branch
                disc_log_probs = log_probs.all_discrete_tensor
                branched_per_action_ent = ModelUtils.break_into_branches(
                    disc_log_probs * disc_log_probs.exp(),
                    self._action_spec.discrete_branches,
                )
                target_current_diff_branched = torch.stack(
                    [
                        torch.sum(_lp, dim=1, keepdim=True) + _te
                        for _lp, _te in zip(
                            branched_per_action_ent, self.target_entropy.discrete
                        )
                    ],
                    dim=1,
                )
                target_current_diff = torch.squeeze(
                    target_current_diff_branched, dim=2
                )
            entropy_loss += -1 * ModelUtils.masked_mean(
                torch.mean(_disc_ent_coef * target_current_diff, dim=1), loss_masks
            )
        if self._action_spec.continuous_size > 0:
            with torch.no_grad():
                cont_log_probs = log_probs.continuous_tensor
                target_current_diff = torch.sum(
                    cont_log_probs + self.target_entropy.continuous, dim=1
                )
            # We update all the _cont_ent_coef as one block
            entropy_loss += -1 * ModelUtils.masked_mean(
                _cont_ent_coef * target_current_diff, loss_masks
            )

        return entropy_loss

    def _condense_q_streams(
        self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor
    ) -> Dict[str, torch.Tensor]:
        """
        Reduces per-action Q outputs to the Q value of the actions actually taken,
        averaged over the discrete branches.
        :param q_output: Q outputs per reward stream, covering all discrete actions.
        :param discrete_actions: Discrete actions that were taken.
        :return: Condensed Q value per reward stream.
        """
        condensed_q_output = {}
        onehot_actions = ModelUtils.actions_to_onehot(
            discrete_actions, self._action_spec.discrete_branches
        )
        for key, item in q_output.items():
            branched_q = ModelUtils.break_into_branches(
                item, self._action_spec.discrete_branches
            )
            # The one-hot mask selects the Q value of the taken action in each branch.
            only_action_qs = torch.stack(
                [
                    torch.sum(_act * _q, dim=1, keepdim=True)
                    for _act, _q in zip(onehot_actions, branched_q)
                ]
            )

            condensed_q_output[key] = torch.mean(only_action_qs, dim=0)
        return condensed_q_output

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Updates the policy, value/Q networks and entropy coefficients using a
        mini-batch, then soft-updates the target network.
        :param batch: Experience mini-batch.
        :param num_sequences: Number of trajectories in batch (not read by this method).
        :return: Output from update process (losses, entropy coefficients, learning rate).
        """
        rewards = {}
        for name in self.reward_signals:
            rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

        n_obs = len(self.policy.behavior_spec.sensor_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        next_obs = ObsUtil.from_buffer_next(batch, n_obs)
        # Convert to tensors
        next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

        act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
        actions = AgentAction.from_dict(batch)

        memories_list = [
            ModelUtils.list_to_tensor(batch["memory"][i])
            for i in range(0, len(batch["memory"]), self.policy.sequence_length)
        ]
        # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
        offset = 1 if self.policy.sequence_length > 1 else 0
        next_memories_list = [
            ModelUtils.list_to_tensor(
                batch["memory"][i][self.policy.m_size // 2 :]
            )  # only pass value part of memory to target network
            for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
        ]

        if len(memories_list) > 0:
            memories = torch.stack(memories_list).unsqueeze(0)
            next_memories = torch.stack(next_memories_list).unsqueeze(0)
        else:
            memories = None
            next_memories = None
        # Q network memories are 0'ed out, since we don't have them during inference.
        q_memories = (
            torch.zeros_like(next_memories) if next_memories is not None else None
        )

        # Copy normalizers from policy
        self.value_network.q1_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.value_network.q2_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        self.target_network.network_body.copy_normalization(
            self.policy.actor_critic.network_body
        )
        (
            sampled_actions,
            log_probs,
            _,
            value_estimates,
            _,
        ) = self.policy.actor_critic.get_action_stats_and_value(
            current_obs,
            masks=act_masks,
            memories=memories,
            sequence_length=self.policy.sequence_length,
        )

        cont_sampled_actions = sampled_actions.continuous_tensor
        cont_actions = actions.continuous_tensor
        # Q values for freshly sampled actions; Q2 is evaluated without gradients.
        q1p_out, q2p_out = self.value_network(
            current_obs,
            cont_sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        # Q values for the actions stored in the buffer (used by the Q losses).
        q1_out, q2_out = self.value_network(
            current_obs,
            cont_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )

        if self._action_spec.discrete_size > 0:
            disc_actions = actions.discrete_tensor
            q1_stream = self._condense_q_streams(q1_out, disc_actions)
            q2_stream = self._condense_q_streams(q2_out, disc_actions)
        else:
            q1_stream, q2_stream = q1_out, q2_out

        with torch.no_grad():
            target_values, _ = self.target_network(
                next_obs,
                memories=next_memories,
                sequence_length=self.policy.sequence_length,
            )
        masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
        dones = ModelUtils.list_to_tensor(batch["done"])

        q1_loss, q2_loss = self.sac_q_loss(
            q1_stream, q2_stream, target_values, dones, rewards, masks
        )
        value_loss = self.sac_value_loss(
            log_probs, value_estimates, q1p_out, q2p_out, masks
        )
        policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
        entropy_loss = self.sac_entropy_loss(log_probs, masks)

        total_value_loss = q1_loss + q2_loss + value_loss

        decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
        ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
        self.policy_optimizer.zero_grad()
        policy_loss.backward()
        self.policy_optimizer.step()

        ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
        self.value_optimizer.zero_grad()
        total_value_loss.backward()
        self.value_optimizer.step()

        ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
        self.entropy_optimizer.zero_grad()
        entropy_loss.backward()
        self.entropy_optimizer.step()

        # Update target network
        ModelUtils.soft_update(
            self.policy.actor_critic.critic, self.target_network, self.tau
        )
        update_stats = {
            "Losses/Policy Loss": policy_loss.item(),
            "Losses/Value Loss": value_loss.item(),
            "Losses/Q1 Loss": q1_loss.item(),
            "Losses/Q2 Loss": q2_loss.item(),
            "Policy/Discrete Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.discrete)
            ).item(),
            "Policy/Continuous Entropy Coeff": torch.mean(
                torch.exp(self._log_ent_coef.continuous)
            ).item(),
            "Policy/Learning Rate": decay_lr,
        }

        return update_stats

    def update_reward_signals(
        self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int
    ) -> Dict[str, float]:
        """
        Updates each named reward signal with its own mini-batch.
        :param reward_signal_minibatches: Mini-batches indexed by reward signal name.
        :param num_sequences: Number of trajectories in batch (not read by this method).
        :return: Merged update statistics of all updated reward signals.
        """
        update_stats: Dict[str, float] = {}
        for name, update_buffer in reward_signal_minibatches.items():
            update_stats.update(self.reward_signals[name].update(update_buffer))
        return update_stats

    def get_modules(self):
        """Returns the networks and optimizers of this optimizer and its reward signals, keyed by name."""
        modules = {
            "Optimizer:value_network": self.value_network,
            "Optimizer:target_network": self.target_network,
            "Optimizer:policy_optimizer": self.policy_optimizer,
            "Optimizer:value_optimizer": self.value_optimizer,
            "Optimizer:entropy_optimizer": self.entropy_optimizer,
        }
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules
class TorchPPOOptimizer(TorchOptimizer):
    """PPO optimizer: clipped policy and value losses optimized by a single Adam optimizer."""

    def __init__(self, policy: TorchPolicy, trainer_settings: TrainerSettings):
        """
        Takes a Policy and trainer settings and creates an Optimizer around the policy.
        The PPO optimizer has a value estimator and a loss function.
        :param policy: A TorchPolicy object that will be updated by this PPO Optimizer.
        :param trainer_settings: Trainer settings that specify the properties of the trainer;
            `hyperparameters` is read as PPOSettings.
        """
        super().__init__(policy, trainer_settings)
        reward_signal_configs = trainer_settings.reward_signals
        reward_signal_names = [key.value for key in reward_signal_configs]

        params = list(self.policy.actor.parameters())
        if policy.shared_critic:
            # The actor doubles as the critic, so its parameters are already in
            # `params`. Adding them again would give the optimizer duplicate
            # parameters (PyTorch warns about this and steps them twice).
            self._critic = policy.actor
        else:
            self._critic = ValueNetwork(
                reward_signal_names,
                policy.behavior_spec.observation_specs,
                network_settings=trainer_settings.network_settings,
            )
            self._critic.to(default_device())
            params += list(self._critic.parameters())

        self.hyperparameters: PPOSettings = cast(
            PPOSettings, trainer_settings.hyperparameters
        )
        self.decay_learning_rate = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.learning_rate,
            1e-10,
            self.trainer_settings.max_steps,
        )
        self.decay_epsilon = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.epsilon,
            0.1,
            self.trainer_settings.max_steps,
        )
        self.decay_beta = ModelUtils.DecayedValue(
            self.hyperparameters.learning_rate_schedule,
            self.hyperparameters.beta,
            1e-5,
            self.trainer_settings.max_steps,
        )

        self.optimizer = torch.optim.Adam(
            params, lr=self.trainer_settings.hyperparameters.learning_rate
        )
        self.stats_name_to_update_name = {
            "Losses/Value Loss": "value_loss",
            "Losses/Policy Loss": "policy_loss",
        }

        self.stream_names = list(self.reward_signals.keys())

    @property
    def critic(self):
        """The network used for value estimates (the actor itself when the critic is shared)."""
        return self._critic

    def ppo_value_loss(
        self,
        values: Dict[str, torch.Tensor],
        old_values: Dict[str, torch.Tensor],
        returns: Dict[str, torch.Tensor],
        epsilon: float,
        loss_masks: torch.Tensor,
    ) -> torch.Tensor:
        """
        Evaluates value loss for PPO.
        :param values: Value output of the current network.
        :param old_values: Value stored with experiences in buffer.
        :param returns: Computed returns.
        :param epsilon: Clipping value for value estimate.
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :return: Value loss averaged over reward streams.
        """
        value_losses = []
        for name, head in values.items():
            old_val_tensor = old_values[name]
            returns_tensor = returns[name]
            # Keep the new estimate within +/- epsilon of the stored one.
            clipped_value_estimate = old_val_tensor + torch.clamp(
                head - old_val_tensor, -1 * epsilon, epsilon
            )
            v_opt_a = (returns_tensor - head) ** 2
            v_opt_b = (returns_tensor - clipped_value_estimate) ** 2
            # Take the larger of the unclipped and clipped squared errors.
            value_loss = ModelUtils.masked_mean(torch.max(v_opt_a, v_opt_b), loss_masks)
            value_losses.append(value_loss)
        value_loss = torch.mean(torch.stack(value_losses))
        return value_loss

    def ppo_policy_loss(
        self,
        advantages: torch.Tensor,
        log_probs: torch.Tensor,
        old_log_probs: torch.Tensor,
        loss_masks: torch.Tensor,
        epsilon: Optional[float] = None,
    ) -> torch.Tensor:
        """
        Evaluate PPO policy loss.
        :param advantages: Computed advantages.
        :param log_probs: Current policy probabilities
        :param old_log_probs: Past policy probabilities
        :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
        :param epsilon: Clipping range for the probability ratio. If None, the
            undecayed hyperparameter epsilon is used.
        :return: Masked mean clipped-surrogate policy loss.
        """
        advantage = advantages.unsqueeze(-1)

        clip_epsilon = self.hyperparameters.epsilon if epsilon is None else epsilon
        r_theta = torch.exp(log_probs - old_log_probs)
        p_opt_a = r_theta * advantage
        p_opt_b = (
            torch.clamp(r_theta, 1.0 - clip_epsilon, 1.0 + clip_epsilon) * advantage
        )
        policy_loss = -1 * ModelUtils.masked_mean(
            torch.min(p_opt_a, p_opt_b), loss_masks
        )
        return policy_loss

    @timed
    def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
        """
        Performs update on model.
        :param batch: Batch of experiences.
        :param num_sequences: Number of sequences to process (not read by this method).
        :return: Results of update.
        """
        # Get decayed parameters
        current_step = self.policy.get_current_step()
        decay_lr = self.decay_learning_rate.get_value(current_step)
        decay_eps = self.decay_epsilon.get_value(current_step)
        decay_bet = self.decay_beta.get_value(current_step)
        returns = {}
        old_values = {}
        for name in self.reward_signals:
            old_values[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.value_estimates_key(name)]
            )
            returns[name] = ModelUtils.list_to_tensor(
                batch[RewardSignalUtil.returns_key(name)]
            )

        n_obs = len(self.policy.behavior_spec.observation_specs)
        current_obs = ObsUtil.from_buffer(batch, n_obs)
        # Convert to tensors
        current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

        act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
        actions = AgentAction.from_buffer(batch)

        # One initial memory per sequence.
        memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
            for i in range(0, len(batch[BufferKey.MEMORY]), self.policy.sequence_length)
        ]
        if len(memories) > 0:
            memories = torch.stack(memories).unsqueeze(0)

        # Get value memories
        value_memories = [
            ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
            for i in range(
                0, len(batch[BufferKey.CRITIC_MEMORY]), self.policy.sequence_length
            )
        ]
        if len(value_memories) > 0:
            value_memories = torch.stack(value_memories).unsqueeze(0)

        log_probs, entropy = self.policy.evaluate_actions(
            current_obs,
            masks=act_masks,
            actions=actions,
            memories=memories,
            seq_len=self.policy.sequence_length,
        )
        values, _ = self.critic.critic_pass(
            current_obs,
            memories=value_memories,
            sequence_length=self.policy.sequence_length,
        )
        old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
        log_probs = log_probs.flatten()
        loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
        value_loss = self.ppo_value_loss(
            values, old_values, returns, decay_eps, loss_masks
        )
        # Use the decayed epsilon for the policy clip as well, so that both
        # losses follow the same schedule.
        policy_loss = self.ppo_policy_loss(
            ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
            log_probs,
            old_log_probs,
            loss_masks,
            epsilon=decay_eps,
        )
        loss = (
            policy_loss
            + 0.5 * value_loss
            - decay_bet * ModelUtils.masked_mean(entropy, loss_masks)
        )

        # Set optimizer learning rate
        ModelUtils.update_learning_rate(self.optimizer, decay_lr)
        self.optimizer.zero_grad()
        loss.backward()

        self.optimizer.step()
        update_stats = {
            # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
            # TODO: After PyTorch is default, change to something more correct.
            "Losses/Policy Loss": torch.abs(policy_loss).item(),
            "Losses/Value Loss": value_loss.item(),
            "Policy/Learning Rate": decay_lr,
            "Policy/Epsilon": decay_eps,
            "Policy/Beta": decay_bet,
        }

        for reward_provider in self.reward_signals.values():
            update_stats.update(reward_provider.update(batch))

        return update_stats

    def get_modules(self):
        """Returns the optimizer and the modules of all reward signals, keyed by name."""
        modules = {"Optimizer": self.optimizer}
        for reward_provider in self.reward_signals.values():
            modules.update(reward_provider.get_modules())
        return modules
class TorchSACOptimizer(TorchOptimizer): class PolicyValueNetwork(nn.Module): def __init__( self, stream_names: List[str], observation_shapes: List[Tuple[int, ...]], network_settings: NetworkSettings, act_type: ActionType, act_size: List[int], ): super().__init__() if act_type == ActionType.CONTINUOUS: num_value_outs = 1 num_action_ins = sum(act_size) else: num_value_outs = sum(act_size) num_action_ins = 0 self.q1_network = ValueNetwork( stream_names, observation_shapes, network_settings, num_action_ins, num_value_outs, ) self.q2_network = ValueNetwork( stream_names, observation_shapes, network_settings, num_action_ins, num_value_outs, ) def forward( self, vec_inputs: List[torch.Tensor], vis_inputs: List[torch.Tensor], actions: Optional[torch.Tensor] = None, memories: Optional[torch.Tensor] = None, sequence_length: int = 1, ) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]: q1_out, _ = self.q1_network( vec_inputs, vis_inputs, actions=actions, memories=memories, sequence_length=sequence_length, ) q2_out, _ = self.q2_network( vec_inputs, vis_inputs, actions=actions, memories=memories, sequence_length=sequence_length, ) return q1_out, q2_out def __init__(self, policy: TorchPolicy, trainer_params: TrainerSettings): super().__init__(policy, trainer_params) hyperparameters: SACSettings = cast(SACSettings, trainer_params.hyperparameters) self.tau = hyperparameters.tau self.init_entcoef = hyperparameters.init_entcoef self.policy = policy self.act_size = policy.act_size policy_network_settings = policy.network_settings self.tau = hyperparameters.tau self.burn_in_ratio = 0.0 # Non-exposed SAC parameters self.discrete_target_entropy_scale = 0.2 # Roughly equal to e-greedy 0.05 self.continuous_target_entropy_scale = 1.0 self.stream_names = list(self.reward_signals.keys()) # Use to reduce "survivor bonus" when using Curiosity or GAIL. 
self.gammas = [_val.gamma for _val in trainer_params.reward_signals.values()] self.use_dones_in_backup = { name: int(not self.reward_signals[name].ignore_done) for name in self.stream_names } self.value_network = TorchSACOptimizer.PolicyValueNetwork( self.stream_names, self.policy.behavior_spec.observation_shapes, policy_network_settings, self.policy.behavior_spec.action_type, self.act_size, ) self.target_network = ValueNetwork( self.stream_names, self.policy.behavior_spec.observation_shapes, policy_network_settings, ) ModelUtils.soft_update( self.policy.actor_critic.critic, self.target_network, 1.0 ) self._log_ent_coef = torch.nn.Parameter( torch.log(torch.as_tensor([self.init_entcoef] * len(self.act_size))), requires_grad=True, ) if self.policy.use_continuous_act: self.target_entropy = torch.as_tensor( -1 * self.continuous_target_entropy_scale * np.prod(self.act_size[0]).astype(np.float32) ) else: self.target_entropy = [ self.discrete_target_entropy_scale * np.log(i).astype(np.float32) for i in self.act_size ] policy_params = list(self.policy.actor_critic.network_body.parameters()) + list( self.policy.actor_critic.distribution.parameters() ) value_params = list(self.value_network.parameters()) + list( self.policy.actor_critic.critic.parameters() ) logger.debug("value_vars") for param in value_params: logger.debug(param.shape) logger.debug("policy_vars") for param in policy_params: logger.debug(param.shape) self.decay_learning_rate = ModelUtils.DecayedValue( hyperparameters.learning_rate_schedule, hyperparameters.learning_rate, 1e-10, self.trainer_settings.max_steps, ) self.policy_optimizer = torch.optim.Adam( policy_params, lr=hyperparameters.learning_rate ) self.value_optimizer = torch.optim.Adam( value_params, lr=hyperparameters.learning_rate ) self.entropy_optimizer = torch.optim.Adam( [self._log_ent_coef], lr=hyperparameters.learning_rate ) self._move_to_device(default_device()) def _move_to_device(self, device: torch.device) -> None: 
self._log_ent_coef.to(device) self.target_network.to(device) self.value_network.to(device) def sac_q_loss( self, q1_out: Dict[str, torch.Tensor], q2_out: Dict[str, torch.Tensor], target_values: Dict[str, torch.Tensor], dones: torch.Tensor, rewards: Dict[str, torch.Tensor], loss_masks: torch.Tensor, ) -> Tuple[torch.Tensor, torch.Tensor]: q1_losses = [] q2_losses = [] # Multiple q losses per stream for i, name in enumerate(q1_out.keys()): q1_stream = q1_out[name].squeeze() q2_stream = q2_out[name].squeeze() with torch.no_grad(): q_backup = rewards[name] + ( (1.0 - self.use_dones_in_backup[name] * dones) * self.gammas[i] * target_values[name] ) _q1_loss = 0.5 * ModelUtils.masked_mean( torch.nn.functional.mse_loss(q_backup, q1_stream), loss_masks ) _q2_loss = 0.5 * ModelUtils.masked_mean( torch.nn.functional.mse_loss(q_backup, q2_stream), loss_masks ) q1_losses.append(_q1_loss) q2_losses.append(_q2_loss) q1_loss = torch.mean(torch.stack(q1_losses)) q2_loss = torch.mean(torch.stack(q2_losses)) return q1_loss, q2_loss def sac_value_loss( self, log_probs: torch.Tensor, values: Dict[str, torch.Tensor], q1p_out: Dict[str, torch.Tensor], q2p_out: Dict[str, torch.Tensor], loss_masks: torch.Tensor, discrete: bool, ) -> torch.Tensor: min_policy_qs = {} with torch.no_grad(): _ent_coef = torch.exp(self._log_ent_coef) for name in values.keys(): if not discrete: min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name]) else: action_probs = log_probs.exp() _branched_q1p = ModelUtils.break_into_branches( q1p_out[name] * action_probs, self.act_size ) _branched_q2p = ModelUtils.break_into_branches( q2p_out[name] * action_probs, self.act_size ) _q1p_mean = torch.mean( torch.stack( [ torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p ] ), dim=0, ) _q2p_mean = torch.mean( torch.stack( [ torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p ] ), dim=0, ) min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean) value_losses = [] if not discrete: for name in values.keys(): 
with torch.no_grad(): v_backup = min_policy_qs[name] - torch.sum( _ent_coef * log_probs, dim=1 ) value_loss = 0.5 * ModelUtils.masked_mean( torch.nn.functional.mse_loss(values[name], v_backup), loss_masks ) value_losses.append(value_loss) else: branched_per_action_ent = ModelUtils.break_into_branches( log_probs * log_probs.exp(), self.act_size ) # We have to do entropy bonus per action branch branched_ent_bonus = torch.stack( [ torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True) for i, _lp in enumerate(branched_per_action_ent) ] ) for name in values.keys(): with torch.no_grad(): v_backup = min_policy_qs[name] - torch.mean( branched_ent_bonus, axis=0 ) value_loss = 0.5 * ModelUtils.masked_mean( torch.nn.functional.mse_loss(values[name], v_backup.squeeze()), loss_masks, ) value_losses.append(value_loss) value_loss = torch.mean(torch.stack(value_losses)) if torch.isinf(value_loss).any() or torch.isnan(value_loss).any(): raise UnityTrainerException("Inf found") return value_loss def sac_policy_loss( self, log_probs: torch.Tensor, q1p_outs: Dict[str, torch.Tensor], loss_masks: torch.Tensor, discrete: bool, ) -> torch.Tensor: _ent_coef = torch.exp(self._log_ent_coef) mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0) if not discrete: mean_q1 = mean_q1.unsqueeze(1) batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1) policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks) else: action_probs = log_probs.exp() branched_per_action_ent = ModelUtils.break_into_branches( log_probs * action_probs, self.act_size ) branched_q_term = ModelUtils.break_into_branches( mean_q1 * action_probs, self.act_size ) branched_policy_loss = torch.stack( [ torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True) for i, (_lp, _qt) in enumerate( zip(branched_per_action_ent, branched_q_term) ) ] ) batch_policy_loss = torch.squeeze(branched_policy_loss) policy_loss = torch.mean(loss_masks * batch_policy_loss) return policy_loss def sac_entropy_loss( self, 
def sac_entropy_loss(
    self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
) -> torch.Tensor:
    """
    Computes the loss used to adapt the (log) entropy coefficient(s). The difference
    between the policy's current log-probability/entropy term and the target entropy
    is computed without gradients, so only self._log_ent_coef receives a gradient.
    :param log_probs: Log probabilities produced by the policy. In the discrete case
        they are split per branch with self.act_size (presumably all actions of all
        branches; TODO confirm against the caller).
    :param loss_masks: Mask used when averaging the per-sample loss.
    :param discrete: Whether to use the discrete (branched) formulation.
    :return: The masked mean of the per-sample entropy-coefficient loss.
    """
    if not discrete:
        with torch.no_grad():
            target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
        # Reduce with the same masked mean as the discrete branch. The previous
        # -torch.mean(coef * masks * diff) kept masked-out entries in the
        # denominator. Assuming masked_mean divides by the number of unmasked
        # entries, the value is unchanged whenever nothing is masked.
        entropy_loss = -1 * ModelUtils.masked_mean(
            self._log_ent_coef * target_current_diff, loss_masks
        )
    else:
        with torch.no_grad():
            # p * log(p) per action, grouped by action branch.
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # One column per branch: summed p * log(p) plus that branch's target.
            target_current_diff_branched = torch.stack(
                [
                    torch.sum(_lp, dim=1, keepdim=True) + _te
                    for _lp, _te in zip(
                        branched_per_action_ent, self.target_entropy
                    )
                ],
                dim=1,
            )
            # Drop the singleton axis left behind by keepdim=True.
            target_current_diff = torch.squeeze(target_current_diff_branched, dim=2)
        entropy_loss = -1 * ModelUtils.masked_mean(
            torch.mean(self._log_ent_coef * target_current_diff, dim=1), loss_masks
        )
    return entropy_loss

def _condense_q_streams(
    self, q_output: Dict[str, torch.Tensor], discrete_actions: torch.Tensor
) -> Dict[str, torch.Tensor]:
    """
    Reduces per-action Q outputs to the Q-value of the action actually taken.
    For every reward stream the Q output is split per branch, multiplied with the
    one-hot encoding of the taken action and summed, which keeps only the taken
    action's Q-value. The per-branch values are then averaged over the branches.
    :param q_output: Q output for each reward stream, keyed by stream name.
    :param discrete_actions: Discrete actions that were taken; converted to one-hot
        form per branch with ModelUtils.actions_to_onehot.
    :return: Dictionary with the same keys as q_output holding the condensed values.
    """
    condensed_q_output = {}
    onehot_actions = ModelUtils.actions_to_onehot(discrete_actions, self.act_size)
    for key, item in q_output.items():
        branched_q = ModelUtils.break_into_branches(item, self.act_size)
        only_action_qs = torch.stack(
            [
                torch.sum(_act * _q, dim=1, keepdim=True)
                for _act, _q in zip(onehot_actions, branched_q)
            ]
        )
        # Average over the branch axis created by torch.stack.
        condensed_q_output[key] = torch.mean(only_action_qs, dim=0)
    return condensed_q_output

@timed
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Performs one SAC update from a mini-batch: computes the Q, value, policy and
    entropy-coefficient losses, steps the policy, value and entropy optimizers in
    that order, and then soft-updates the target network towards the policy's critic.
    :param batch: Experience mini-batch.
    :param num_sequences: Number of trajectories in batch. Not read by this method.
    :return: Dictionary of update statistics: policy, value, Q1 and Q2 losses, the
        mean entropy coefficient and the learning rate that was applied.
    """
    rewards = {}
    for name in self.reward_signals:
        rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

    vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
    next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])]
    act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
    if self.policy.use_continuous_act:
        actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
    else:
        actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)

    # Only the first memory of every sequence is used as the initial memory.
    memories_list = [
        ModelUtils.list_to_tensor(batch["memory"][i])
        for i in range(0, len(batch["memory"]), self.policy.sequence_length)
    ]
    # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
    offset = 1 if self.policy.sequence_length > 1 else 0
    next_memories_list = [
        ModelUtils.list_to_tensor(
            batch["memory"][i][self.policy.m_size // 2 :]
        )  # only pass value part of memory to target network
        for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
    ]

    if memories_list:
        memories = torch.stack(memories_list).unsqueeze(0)
        next_memories = torch.stack(next_memories_list).unsqueeze(0)
    else:
        memories = None
        next_memories = None
    # Q network memories are 0'ed out, since we don't have them during inference.
    q_memories = (
        torch.zeros_like(next_memories) if next_memories is not None else None
    )

    vis_obs: List[torch.Tensor] = []
    next_vis_obs: List[torch.Tensor] = []
    if self.policy.use_vis_obs:
        for idx, _ in enumerate(
            self.policy.actor_critic.network_body.visual_processors
        ):
            vis_ob = ModelUtils.list_to_tensor(batch[f"visual_obs{idx}"])
            vis_obs.append(vis_ob)
            next_vis_ob = ModelUtils.list_to_tensor(batch[f"next_visual_obs{idx}"])
            next_vis_obs.append(next_vis_ob)

    # Copy normalizers from policy
    self.value_network.q1_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    self.value_network.q2_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    self.target_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    (sampled_actions, log_probs, _, _) = self.policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        memories=memories,
        seq_len=self.policy.sequence_length,
        all_log_probs=not self.policy.use_continuous_act,
    )
    value_estimates, _ = self.policy.actor_critic.critic_pass(
        vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
    )
    if self.policy.use_continuous_act:
        squeezed_actions = actions.squeeze(-1)
        # Q-values of the actions sampled from the current policy.
        q1p_out, q2p_out = self.value_network(
            vec_obs,
            vis_obs,
            sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        # Q-values of the actions stored in the batch.
        q1_out, q2_out = self.value_network(
            vec_obs,
            vis_obs,
            squeezed_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        q1_stream, q2_stream = q1_out, q2_out
    else:
        # The Q outputs fed to the policy and value losses carry no gradient.
        with torch.no_grad():
            q1p_out, q2p_out = self.value_network(
                vec_obs,
                vis_obs,
                memories=q_memories,
                sequence_length=self.policy.sequence_length,
            )
        q1_out, q2_out = self.value_network(
            vec_obs,
            vis_obs,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        # Keep only the Q-values of the actions stored in the batch.
        q1_stream = self._condense_q_streams(q1_out, actions)
        q2_stream = self._condense_q_streams(q2_out, actions)

    with torch.no_grad():
        target_values, _ = self.target_network(
            next_vec_obs,
            next_vis_obs,
            memories=next_memories,
            sequence_length=self.policy.sequence_length,
        )
    masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
    use_discrete = not self.policy.use_continuous_act
    dones = ModelUtils.list_to_tensor(batch["done"])

    q1_loss, q2_loss = self.sac_q_loss(
        q1_stream, q2_stream, target_values, dones, rewards, masks
    )
    value_loss = self.sac_value_loss(
        log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
    )
    policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
    entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)

    total_value_loss = q1_loss + q2_loss + value_loss

    # All three optimizers use the same decayed learning rate.
    decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
    ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
    self.value_optimizer.zero_grad()
    total_value_loss.backward()
    self.value_optimizer.step()

    ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
    self.entropy_optimizer.zero_grad()
    entropy_loss.backward()
    self.entropy_optimizer.step()

    # Update target network
    ModelUtils.soft_update(
        self.policy.actor_critic.critic, self.target_network, self.tau
    )
    update_stats = {
        "Losses/Policy Loss": policy_loss.item(),
        "Losses/Value Loss": value_loss.item(),
        "Losses/Q1 Loss": q1_loss.item(),
        "Losses/Q2 Loss": q2_loss.item(),
        "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
        "Policy/Learning Rate": decay_lr,
    }

    return update_stats
def update_reward_signals(
    self, reward_signal_minibatches: Mapping[str, AgentBuffer], num_sequences: int
) -> Dict[str, float]:
    """
    Updates every reward signal that has a mini-batch and merges their statistics.
    :param reward_signal_minibatches: Mini-batch for each reward signal, keyed by the
        reward signal's name in self.reward_signals.
    :param num_sequences: Not read by this method.
    :return: The statistics returned by each reward signal's update, merged into one
        dictionary. On duplicate keys the later reward signal wins.
    """
    merged_stats: Dict[str, float] = {}
    for signal_name, minibatch in reward_signal_minibatches.items():
        provider = self.reward_signals[signal_name]
        signal_stats = provider.update(minibatch)
        merged_stats.update(signal_stats)
    return merged_stats
def get_modules(self):
    """
    Collects the named modules owned by this optimizer.
    :return: Dictionary mapping a name to the value network, the target network and
        the policy, value and entropy optimizers, extended with whatever each reward
        signal's own get_modules() returns.
    """
    registry = {}
    registry["Optimizer:value_network"] = self.value_network
    registry["Optimizer:target_network"] = self.target_network
    registry["Optimizer:policy_optimizer"] = self.policy_optimizer
    registry["Optimizer:value_optimizer"] = self.value_optimizer
    registry["Optimizer:entropy_optimizer"] = self.entropy_optimizer
    # Reward signals contribute their own modules; later entries win on key clashes.
    for provider in self.reward_signals.values():
        provider_modules = provider.get_modules()
        registry.update(provider_modules)
    return registry