def sac_entropy_loss(
    self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
) -> torch.Tensor:
    """
    Loss used to tune the (log) entropy coefficient.

    The gap between the policy's log-probability term and the target entropy is
    computed under ``torch.no_grad()``, so the returned loss only carries gradient
    through ``self._log_ent_coef``.

    :param log_probs: Policy log-probabilities. In the discrete case they are split
        per branch with ``ModelUtils.break_into_branches(..., self.act_size)``.
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :param discrete: Whether the action space is discrete.
    :return: The negated masked mean of ``self._log_ent_coef * gap``.
    """
    with torch.no_grad():
        if discrete:
            ent_per_branch = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # One (batch, 1) gap per branch, each offset by that branch's target.
            branch_gaps = [
                torch.sum(branch_ent, axis=1, keepdim=True) + branch_target
                for branch_ent, branch_target in zip(
                    ent_per_branch, self.target_entropy
                )
            ]
            gap = torch.squeeze(torch.stack(branch_gaps, axis=1), axis=2)
        else:
            gap = torch.sum(log_probs + self.target_entropy, dim=1)

    weighted_gap = self._log_ent_coef * gap
    if discrete:
        # Average over the branch axis before masking.
        weighted_gap = torch.mean(weighted_gap, axis=1)
    return -1 * ModelUtils.masked_mean(weighted_gap, loss_masks)
def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
    """
    Assert that two policies produce identical log-probs for the same input.

    A single-agent decision step is generated from ``policy1``'s behavior spec, and both
    policies are sampled on it (``all_log_probs=True``) without gradients.
    """
    decision_step, _ = mb.create_steps_from_behavior_spec(
        policy1.behavior_spec, num_agents=1
    )
    split_obs, masks = policy1._split_decision_step(decision_step)
    vec_obs = [torch.as_tensor(split_obs.vector_observations)]
    vis_obs = list(map(torch.as_tensor, split_obs.visual_observations))
    agent_ids = list(decision_step.agent_id)
    memories = torch.as_tensor(policy1.retrieve_memories(agent_ids)).unsqueeze(0)

    with torch.no_grad():
        _, first_log_probs, _, _, _ = policy1.sample_actions(
            vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
        )
        _, second_log_probs, _, _, _ = policy2.sample_actions(
            vec_obs, vis_obs, masks=masks, memories=memories, all_log_probs=True
        )
    np.testing.assert_array_equal(first_log_probs, second_log_probs)
def sac_q_loss(
    self,
    q1_out: Dict[str, torch.Tensor],
    q2_out: Dict[str, torch.Tensor],
    target_values: Dict[str, torch.Tensor],
    dones: torch.Tensor,
    rewards: Dict[str, torch.Tensor],
    loss_masks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Compute the twin-Q Bellman losses, averaged over reward streams.

    For every reward stream ``name`` (index ``i``), the backup target is
    ``rewards[name] + (1 - use_dones_in_backup[name] * dones) * gammas[i] * target_values[name]``,
    computed without gradients. Each Q head is regressed onto it with a masked
    mean of the per-sample squared error.

    :param q1_out: Q1 outputs keyed by reward stream.
    :param q2_out: Q2 outputs keyed by reward stream.
    :param target_values: Target-network values keyed by reward stream.
    :param dones: Done flags for the batch.
    :param rewards: Rewards keyed by reward stream.
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :return: ``(q1_loss, q2_loss)``, each the mean over streams.
    """
    q1_losses = []
    q2_losses = []
    # Multiple q losses per stream
    for i, name in enumerate(q1_out.keys()):
        q1_stream = q1_out[name].squeeze()
        q2_stream = q2_out[name].squeeze()
        with torch.no_grad():
            q_backup = rewards[name] + (
                (1.0 - self.use_dones_in_backup[name] * dones)
                * self.gammas[i]
                * target_values[name]
            )
        # reduction="none" keeps the per-sample squared errors so the mask can
        # exclude padded steps. The default "mean" reduction collapses to a scalar
        # first, and a masked mean of a constant is just that constant, so the mask
        # would have no effect.
        _q1_loss = 0.5 * ModelUtils.masked_mean(
            torch.nn.functional.mse_loss(q_backup, q1_stream, reduction="none"),
            loss_masks,
        )
        _q2_loss = 0.5 * ModelUtils.masked_mean(
            torch.nn.functional.mse_loss(q_backup, q2_stream, reduction="none"),
            loss_masks,
        )

        q1_losses.append(_q1_loss)
        q2_losses.append(_q2_loss)
    q1_loss = torch.mean(torch.stack(q1_losses))
    q2_loss = torch.mean(torch.stack(q2_losses))
    return q1_loss, q2_loss
def evaluate(
    self, decision_requests: DecisionSteps, global_agent_ids: List[str]
) -> Dict[str, Any]:
    """
    Run the policy on a batch of decision requests and collect its outputs.

    :param decision_requests: DecisionStep object containing inputs.
    :param global_agent_ids: Agent ids whose stored memories are fetched via
        ``self.retrieve_memories``.
    :return: Dict with "action", "pre_action", "log_probs", "entropy",
        "learning_rate", plus "memory_out" when the policy is recurrent.
    """
    split_obs, masks = self._split_decision_step(decision_requests)
    vec_obs = [torch.as_tensor(split_obs.vector_observations)]
    vis_obs = list(map(torch.as_tensor, split_obs.visual_observations))
    memories = torch.as_tensor(
        self.retrieve_memories(global_agent_ids)
    ).unsqueeze(0)

    with torch.no_grad():
        action, log_probs, entropy, memories = self.sample_actions(
            vec_obs, vis_obs, masks=masks, memories=memories
        )
        run_out = {
            "action": ModelUtils.to_numpy(action),
            # Todo - make pre_action difference
            "pre_action": ModelUtils.to_numpy(action),
            "log_probs": ModelUtils.to_numpy(log_probs),
            "entropy": ModelUtils.to_numpy(entropy),
            "learning_rate": 0.0,
        }
        if self.use_recurrent:
            run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)
    return run_out
def evaluate(
    self, decision_requests: DecisionSteps, global_agent_ids: List[str]
) -> Dict[str, Any]:
    """
    Run the policy on a batch of decision requests and collect its outputs.

    :param decision_requests: DecisionStep object containing inputs.
    :param global_agent_ids: Agent ids whose stored memories are fetched via
        ``self.retrieve_memories``.
    :return: Dict with "action" (unclipped ActionTuple), "env_action" (ActionTuple
        clipped according to ``self._clip_action``), "log_probs", "entropy",
        "learning_rate", plus "memory_out" when the policy is recurrent.
    """
    observations = decision_requests.obs
    masks = self._extract_masks(decision_requests)
    tensor_obs = [torch.as_tensor(np_ob) for np_ob in observations]
    memories = torch.as_tensor(
        self.retrieve_memories(global_agent_ids)
    ).unsqueeze(0)

    with torch.no_grad():
        action, log_probs, entropy, memories = self.sample_actions(
            tensor_obs, masks=masks, memories=memories
        )
        run_out = {
            "action": action.to_action_tuple(),
            # The clipped version is sent to the environment only; it is not
            # what gets saved to the buffer.
            "env_action": action.to_action_tuple(clip=self._clip_action),
            "log_probs": log_probs.to_log_probs_tuple(),
            "entropy": ModelUtils.to_numpy(entropy),
            "learning_rate": 0.0,
        }
        if self.use_recurrent:
            run_out["memory_out"] = ModelUtils.to_numpy(memories).squeeze(0)
    return run_out
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
    """
    Compute the GAIL reward for each sample in ``mini_batch``.

    The discriminator estimate ``D`` is produced without VAIL noise and without
    gradients. The reward is ``-log(1 - D * (1 - EPSILON))``; the EPSILON scaling
    keeps the log argument away from zero.

    :param mini_batch: Batch forwarded to the discriminator's ``compute_estimate``.
    :return: Per-sample rewards as a NumPy array.
    """
    with torch.no_grad():
        estimates, _ = self._discriminator_network.compute_estimate(
            mini_batch, use_vail_noise=False
        )
        eps = self._discriminator_network.EPSILON
        scaled_estimates = estimates.squeeze(dim=1) * (1.0 - eps)
        rewards = -torch.log(1.0 - scaled_estimates)
        return ModelUtils.to_numpy(rewards)
def _compare_two_policies(policy1: TorchPolicy, policy2: TorchPolicy) -> None:
    """
    Assert that two policies produce identical discrete log-probs for the same input.

    Both actors are first moved to ``default_device()``. A single-agent decision step
    is generated from ``policy1``'s behavior spec, and each policy is sampled on it
    without gradients.
    """
    for policy in (policy1, policy2):
        policy.actor = policy.actor.to(default_device())

    decision_step, _ = mb.create_steps_from_behavior_spec(
        policy1.behavior_spec, num_agents=1
    )
    raw_obs = decision_step.obs
    masks = policy1._extract_masks(decision_step)
    agent_ids = list(decision_step.agent_id)
    memories = torch.as_tensor(policy1.retrieve_memories(agent_ids)).unsqueeze(0)
    tensor_obs = [ModelUtils.list_to_tensor(single_obs) for single_obs in raw_obs]

    with torch.no_grad():
        _, first_log_probs, _, _ = policy1.sample_actions(
            tensor_obs, masks=masks, memories=memories
        )
        _, second_log_probs, _, _ = policy2.sample_actions(
            tensor_obs, masks=masks, memories=memories
        )

    first_np = ModelUtils.to_numpy(first_log_probs.all_discrete_tensor)
    second_np = ModelUtils.to_numpy(second_log_probs.all_discrete_tensor)
    np.testing.assert_array_equal(first_np, second_np)
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Takes a List of Tensors and returns a List of mask Tensor with 1 if the input was
    all zeros (on dimension 2) and 0 otherwise. This is used in the Attention layer to
    mask the padding observations.

    Precisely, an entry is masked (1.0) when the sum of squares along dimension 2 is
    below 0.01, so near-zero entities are masked as well as exact zeros. The masks are
    float tensors built under ``torch.no_grad()``, with dimension 2 reduced away.

    :param entities: Entity tensors, indexed on dims 1 and 2 (so rank >= 3);
        presumably (batch, num_entities, entity_size) -- TODO confirm.
    :return: One mask tensor per input tensor.
    """
    with torch.no_grad():

        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # We ignore a TracerWarning from PyTorch that warns that doing
                # shape[n].item() will cause the trace to be incorrect (the trace might
                # not generalize to other inputs)
                # We ignore this warning because we know the model will always be
                # run with inputs of the same shape
                warnings.simplefilter("ignore")
                # When exporting to ONNX, we want to transpose the entities. This is
                # because ONNX only support input in NCHW (channel first) format.
                # Barracuda also expect to get data in NCHW.
                # NOTE(review): the transpose swaps dims 1 and 2, then the reshape
                # restores the original (shape[1], shape[2]) layout. This reinterprets
                # the data order rather than changing the shape. `.item()` on shape
                # entries is only expected to work while tracing -- confirm.
                entities = [
                    torch.transpose(obs, 2, 1).reshape(
                        -1, obs.shape[1].item(), obs.shape[2].item()
                    )
                    for obs in entities
                ]

        # Generate the masking tensors for each entities tensor (mask only if all zeros)
        key_masks: List[torch.Tensor] = [
            (torch.sum(ent ** 2, axis=2) < 0.01).float() for ent in entities
        ]
    return key_masks
def forward(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    actions: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
    q1_grad: bool = True,
    q2_grad: bool = True,
) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
    """
    Run both Q networks (Q1 then Q2) on the same inputs.

    Each network can independently be evaluated without gradient tracking. When a
    ``*_grad`` flag is True, no grad context is entered, so the caller's current grad
    mode applies.

    :param vec_inputs: List of vector observation tensors.
    :param vis_inputs: List of visual observation tensors.
    :param actions: Action tensor for a continuous Q function; otherwise None.
    :param memories: Initial memories if using memory; otherwise None.
    :param sequence_length: Sequence length if using memory.
    :param q1_grad: If False, Q1 is evaluated under ``torch.no_grad()``.
    :param q2_grad: If False, Q2 is evaluated under ``torch.no_grad()``.
    :return: ``(q1_out, q2_out)``, the first element of each network's output.
        (Each presumably maps reward-stream name to Q -- confirm against the networks.)
    """

    def _evaluate(network, track_grad: bool):
        # ExitStack lets us enter torch.no_grad() only when gradients are unwanted.
        with ExitStack() as stack:
            if not track_grad:
                stack.enter_context(torch.no_grad())
            out, _ = network(
                vec_inputs,
                vis_inputs,
                actions=actions,
                memories=memories,
                sequence_length=sequence_length,
            )
        return out

    return _evaluate(self.q1_network, q1_grad), _evaluate(self.q2_network, q2_grad)
def update(self, mini_batch: AgentBuffer) -> Dict[str, np.ndarray]:
    """
    Perform one optimisation step of the RND training network.

    The output of ``self._random_network`` is used as a fixed target (no gradients).
    ``self._training_network`` is regressed onto it, using the squared error summed
    over dim 1 and then averaged over the batch.

    :param mini_batch: Batch fed to both networks.
    :return: ``{"Losses/RND Loss": loss}`` with the loss as a NumPy array.
    """
    with torch.no_grad():
        target = self._random_network(mini_batch)
    prediction = self._training_network(mini_batch)
    per_sample_error = ((prediction - target) ** 2).sum(dim=1)
    loss = per_sample_error.mean()

    optimizer = self.optimizer
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    loss_value = loss.detach().cpu().numpy()
    return {"Losses/RND Loss": loss_value}
def sac_entropy_loss(
    self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
) -> torch.Tensor:
    """
    Loss for the entropy coefficients of a (possibly hybrid) action space.

    Discrete and continuous actions each contribute one term when present. The gap to
    the target entropy is computed under ``torch.no_grad()``, so gradients only reach
    ``self._log_ent_coef.discrete`` / ``self._log_ent_coef.continuous``.

    :param log_probs: Log-probs of sampled actions. ``all_discrete_tensor`` and/or
        ``continuous_tensor`` are read depending on the action spec.
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :return: Sum of the per-action-type terms (the integer 0 if the spec has neither
        discrete nor continuous actions).
    """
    cont_coef = self._log_ent_coef.continuous
    disc_coef = self._log_ent_coef.discrete
    spec = self._action_spec
    loss_terms = []

    if spec.discrete_size > 0:
        with torch.no_grad():
            disc_lp = log_probs.all_discrete_tensor
            # Split the probability-weighted discrete log-probs into their branches.
            ent_per_branch = ModelUtils.break_into_branches(
                disc_lp * disc_lp.exp(), spec.discrete_branches
            )
            branch_gaps = [
                torch.sum(branch_ent, axis=1, keepdim=True) + branch_target
                for branch_ent, branch_target in zip(
                    ent_per_branch, self.target_entropy.discrete
                )
            ]
            disc_gap = torch.squeeze(torch.stack(branch_gaps, axis=1), axis=2)
        loss_terms.append(
            -1
            * ModelUtils.masked_mean(
                torch.mean(disc_coef * disc_gap, axis=1), loss_masks
            )
        )

    if spec.continuous_size > 0:
        with torch.no_grad():
            cont_gap = torch.sum(
                log_probs.continuous_tensor + self.target_entropy.continuous, dim=1
            )
        # All continuous coefficients are updated together as one block.
        loss_terms.append(
            -1 * ModelUtils.masked_mean(cont_coef * cont_gap, loss_masks)
        )

    return sum(loss_terms)
def get_zero_entities_mask(
    observations: List[torch.Tensor],
) -> List[torch.Tensor]:
    """
    Build one padding mask per entity tensor for the attention layer.

    An entry is 1.0 when the sum of squares over dimension 2 is below 0.01, meaning the
    entity is (near) all-zero padding, and 0.0 otherwise. Masks are computed under
    ``torch.no_grad()``.

    :param observations: Entity tensors with at least three dimensions.
    :return: Float masks, one per input, with dimension 2 reduced away.
    """
    key_masks: List[torch.Tensor] = []
    with torch.no_grad():
        for entity in observations:
            squared_norm = torch.sum(entity ** 2, axis=2)
            key_masks.append((squared_norm < 0.01).float())
    return key_masks
def compute_loss(
    self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
    """
    Given a policy mini_batch and an expert mini_batch, computes the loss of the
    discriminator.

    The loss is the binary cross-entropy style discriminator loss. When
    ``self._settings.use_vail`` is set, a beta-weighted KL term is added and
    ``self._beta`` is updated in place (after ``vail_loss`` has been computed with the
    previous beta). When ``self.gradient_penalty_weight > 0``, a gradient-magnitude
    penalty is also added.

    :param policy_batch: Mini-batch of policy experiences.
    :param expert_batch: Mini-batch of expert demonstrations.
    :return: ``(total_loss, stats_dict)``. The stats values are produced with
        ``.item()`` (Python floats).
    """
    total_loss = torch.zeros(1)
    stats_dict: Dict[str, np.ndarray] = {}
    policy_estimate, policy_mu = self.compute_estimate(
        policy_batch, use_vail_noise=True
    )
    expert_estimate, expert_mu = self.compute_estimate(
        expert_batch, use_vail_noise=True
    )
    stats_dict["Policy/GAIL Policy Estimate"] = policy_estimate.mean().item()
    stats_dict["Policy/GAIL Expert Estimate"] = expert_estimate.mean().item()
    # EPSILON keeps both log arguments strictly positive.
    discriminator_loss = -(
        torch.log(expert_estimate + self.EPSILON)
        + torch.log(1.0 - policy_estimate + self.EPSILON)
    ).mean()
    stats_dict["Losses/GAIL Loss"] = discriminator_loss.item()
    total_loss += discriminator_loss
    if self._settings.use_vail:
        # KL divergence loss (encourage latent representation to be normal)
        kl_loss = torch.mean(
            -torch.sum(
                1
                + (self._z_sigma ** 2).log()
                - 0.5 * expert_mu ** 2
                - 0.5 * policy_mu ** 2
                - (self._z_sigma ** 2),
                dim=1,
            )
        )
        vail_loss = self._beta * (kl_loss - self.mutual_information)
        # Dual-ascent style update of beta, clamped at 0; done without gradients.
        with torch.no_grad():
            self._beta.data = torch.max(
                self._beta + self.alpha * (kl_loss - self.mutual_information),
                torch.tensor(0.0),
            )
        total_loss += vail_loss
        stats_dict["Policy/GAIL Beta"] = self._beta.item()
        stats_dict["Losses/GAIL KL Loss"] = kl_loss.item()
    if self.gradient_penalty_weight > 0.0:
        gradient_magnitude_loss = (
            self.gradient_penalty_weight
            * self.compute_gradient_magnitude(policy_batch, expert_batch)
        )
        stats_dict["Policy/GAIL Grad Mag Loss"] = gradient_magnitude_loss.item()
        total_loss += gradient_magnitude_loss
    return total_loss, stats_dict
def update(self, vector_input: torch.Tensor) -> None:
    """
    Fold a batch of vectors into the running normalization statistics.

    This is a batched Welford-style update. ``running_mean`` moves by the batch's mean
    deviation over the new step count, and ``running_variance`` accumulates the summed
    product of deviations from the old and new means. All work is done under
    ``torch.no_grad()``.

    :param vector_input: Batch of vectors; dim 0 is the batch dimension.
    """
    with torch.no_grad():
        batch_size = vector_input.shape[0]
        steps = self.normalization_steps + batch_size

        delta_old = vector_input - self.running_mean
        mean = self.running_mean + (delta_old / steps).sum(0)
        delta_new = vector_input - mean
        variance = self.running_variance + (delta_new * delta_old).sum(0)

        # Rebind the attributes instead of updating in place (faster than in-place
        # data updates).
        self.running_mean: torch.Tensor = mean
        self.running_variance: torch.Tensor = variance
        self.normalization_steps: torch.Tensor = steps
def test_predict_minimum_training():
    """
    Train EntityEmbeddings + ResidualSelfAttention + LinearEncoder to predict the
    index of the minimum among a variable number (1..n_k) of random scalars.

    Each scalar is tagged with a one-hot position code. The test asserts that the
    cross-entropy loss averaged over the last 20 steps is below 0.1.
    """
    # of 5 numbers, predict index of min
    np.random.seed(1336)
    torch.manual_seed(1336)
    n_k = 5
    size = n_k + 1
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(
        size, [size], embedding_size, [n_k], concat_self=False
    )
    transformer = ResidualSelfAttention(embedding_size)
    l_layer = LinearEncoder(embedding_size, 2, n_k)
    loss = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )

    batch_size = 200
    # Position codes 0..n_k-1. torch.arange replaces the deprecated torch.range
    # (whose end is inclusive). dtype stays float32, matching torch.range's default.
    onehots = ModelUtils.actions_to_onehot(
        torch.arange(0, n_k, dtype=torch.float32).unsqueeze(1), [n_k]
    )[0]
    onehots = onehots.expand((batch_size, -1, -1))
    losses = []
    for _ in range(400):
        num = np.random.randint(0, n_k)
        inp = torch.rand((batch_size, num + 1, 1))
        with torch.no_grad():
            # create the target : The minimum
            argmin = torch.argmin(inp, dim=1)
            argmin = argmin.squeeze()
            argmin = argmin.detach()
        sliced_oh = onehots[:, : num + 1]
        inp = torch.cat([inp, sliced_oh], dim=2)

        embeddings = entity_embeddings(inp, [inp])
        masks = EntityEmbeddings.get_masks([inp])
        prediction = transformer(embeddings, masks)
        prediction = l_layer(prediction)
        ce = loss(prediction, argmin)
        losses.append(ce.item())
        optimizer.zero_grad()
        ce.backward()
        optimizer.step()
    assert np.array(losses[-20:]).mean() < 0.1
def test_simple_transformer_training():
    """
    Train EntityEmbeddings + ResidualSelfAttention + a linear layer to output, for a
    random query point ``center``, the closest (Euclidean) of ``n_k`` random keys.

    The test asserts that every step's error is below the first step's error, and that
    the final error is below 0.3.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    # NOTE(review): entity_embeddings.parameters() are not passed to the optimizer,
    # so the embedding stays at its initial weights for the whole run. Confirm this is
    # intentional before changing it: the seeded per-step assertion below depends on
    # the exact training dynamics.
    optimizer = torch.optim.Adam(
        list(transformer.parameters()) + list(l_layer.parameters()), lr=0.001
    )
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        # Points uniformly in [-point_range, point_range).
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3
def _compare_two_optimizers(opt1: TorchOptimizer, opt2: TorchOptimizer) -> None:
    """
    Assert that two optimizers give identical value estimates on the same trajectory.

    A fake 10-step trajectory is generated from ``opt1``'s behavior spec. Each
    optimizer's value estimates are computed without gradients and compared stream by
    stream (in dict iteration order).
    """
    spec = opt1.policy.behavior_spec
    trajectory = mb.make_fake_trajectory(
        length=10,
        observation_specs=spec.observation_specs,
        action_spec=spec.action_spec,
        max_step_complete=True,
    )
    with torch.no_grad():
        _, first_values, _ = opt1.get_trajectory_value_estimates(
            trajectory.to_agentbuffer(), trajectory.next_obs, done=False
        )
        _, second_values, _ = opt2.get_trajectory_value_estimates(
            trajectory.to_agentbuffer(), trajectory.next_obs, done=False
        )

    for first, second in zip(first_values.values(), second_values.values()):
        np.testing.assert_array_equal(first, second)
def test_predict_closest_training():
    """
    Train an EntityEmbedding (with a self embedding) + ResidualSelfAttention + a linear
    layer to output, for each query point ``center``, the closest (Euclidean) of
    ``n_k`` random keys.

    The test asserts that the error after the last training step is below 0.02.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
def test_all_masking(mask_value):
    """
    Smoke test: a few training steps with a mask that is uniformly ``mask_value``
    across all keys must run without raising.

    :param mask_value: Value used to fill every mask entry (the original comment says
        all zeros or all ones; presumably supplied by a parametrize decorator outside
        this view -- confirm).
    """
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum(
                (center.reshape((batch_size, 1, size)) - key) ** 2, dim=2
            )
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        # Same value for every (batch, key) position.
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
def soft_update(source: nn.Module, target: nn.Module, tau: float) -> None:
    """
    Polyak-average ``source``'s parameters into ``target`` in place.

    For every parameter pair (matched positionally by ``parameters()`` order):
    ``target = tau * source + (1 - tau) * target``. With ``tau == 1`` the source is
    copied into the target.

    :param source: Module whose parameters are blended in.
    :param target: Module updated in place; must have the same parameters as source.
    :param tau: Fraction of the source parameters to mix in.
    """
    keep = 1.0 - tau
    with torch.no_grad():
        for src, tgt in zip(source.parameters(), target.parameters()):
            tgt.data.mul_(keep).add_(src.data, alpha=tau)
def test_multi_head_attention_training():
    """
    Train a MultiHeadAttention layer (values == keys) to return the key closest to
    the single query.

    The test asserts that every step's error is below the first step's error, and that
    the final error is below 0.5.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(50):
        # Points uniformly in [-point_range, point_range).
        query = torch.rand((batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand((batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            # (query has n_q == 1 entries, so it broadcasts against the n_k keys).
            distance = torch.sum((query - key) ** 2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target) ** 2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5
def sac_value_loss(
    self,
    log_probs: torch.Tensor,
    values: Dict[str, torch.Tensor],
    q1p_out: Dict[str, torch.Tensor],
    q2p_out: Dict[str, torch.Tensor],
    loss_masks: torch.Tensor,
    discrete: bool,
) -> torch.Tensor:
    """
    Loss for the state-value heads, averaged over reward streams.

    The value target is ``min(Q1, Q2)`` evaluated for the current policy's actions,
    minus the entropy bonus ``exp(self._log_ent_coef) * log_probs``. In the discrete
    case, Q values are probability-weighted, summed per branch and averaged over
    branches, and the entropy bonus is computed per branch and averaged. All targets
    are built without gradients.

    :param log_probs: Policy log-probabilities (per action for discrete spaces).
    :param values: Value estimates keyed by reward stream.
    :param q1p_out: Q1 outputs for the policy's actions, keyed by reward stream.
    :param q2p_out: Q2 outputs for the policy's actions, keyed by reward stream.
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :param discrete: Whether the action space is discrete.
    :return: Mean value loss across streams.
    :raises UnityTrainerException: If the loss contains inf or NaN.
    """
    min_policy_qs = {}
    with torch.no_grad():
        _ent_coef = torch.exp(self._log_ent_coef)
        for name in values.keys():
            if not discrete:
                min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
            else:
                action_probs = log_probs.exp()
                _branched_q1p = ModelUtils.break_into_branches(
                    q1p_out[name] * action_probs, self.act_size
                )
                _branched_q2p = ModelUtils.break_into_branches(
                    q2p_out[name] * action_probs, self.act_size
                )
                _q1p_mean = torch.mean(
                    torch.stack(
                        [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q1p]
                    ),
                    dim=0,
                )
                _q2p_mean = torch.mean(
                    torch.stack(
                        [torch.sum(_br, dim=1, keepdim=True) for _br in _branched_q2p]
                    ),
                    dim=0,
                )

                min_policy_qs[name] = torch.min(_q1p_mean, _q2p_mean)

    value_losses = []
    if not discrete:
        for name in values.keys():
            with torch.no_grad():
                v_backup = min_policy_qs[name] - torch.sum(
                    _ent_coef * log_probs, dim=1
                )
            # reduction="none" keeps per-sample errors so the mask can take effect.
            # A scalar "mean" result would make the masked mean a no-op.
            value_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(values[name], v_backup, reduction="none"),
                loss_masks,
            )
            value_losses.append(value_loss)
    else:
        # The entropy bonus is only consumed by the no-grad backup below, so build it
        # without recording an autograd graph.
        with torch.no_grad():
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
        for name in values.keys():
            with torch.no_grad():
                v_backup = min_policy_qs[name] - torch.mean(
                    branched_ent_bonus, axis=0
                )
            value_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(
                    values[name], v_backup.squeeze(), reduction="none"
                ),
                loss_masks,
            )
            value_losses.append(value_loss)
    value_loss = torch.mean(torch.stack(value_losses))
    if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
        raise UnityTrainerException("Inf found")
    return value_loss
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Updates model using buffer.

    Performs one SAC step. It converts the batch to tensors, samples actions from the
    policy, evaluates the Q networks and the target network, then steps the policy,
    value (Q1 + Q2 + V) and entropy-coefficient optimizers in that order. Finally it
    soft-updates the target network from the policy's critic.

    :param batch: Experience mini-batch.
    :param num_sequences: Number of trajectories in batch (not read in this body).
    :return: Dict of scalar losses, the mean entropy coefficient and the learning rate.
    """
    rewards = {}
    for name in self.reward_signals:
        rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

    vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
    next_vec_obs = [ModelUtils.list_to_tensor(batch["next_vector_in"])]
    act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
    if self.policy.use_continuous_act:
        actions = ModelUtils.list_to_tensor(batch["actions"]).unsqueeze(-1)
    else:
        actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)

    # One initial memory per sequence (every sequence_length-th entry).
    memories_list = [
        ModelUtils.list_to_tensor(batch["memory"][i])
        for i in range(0, len(batch["memory"]), self.policy.sequence_length)
    ]
    # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
    offset = 1 if self.policy.sequence_length > 1 else 0
    next_memories_list = [
        ModelUtils.list_to_tensor(
            batch["memory"][i][self.policy.m_size // 2 :]
        )  # only pass value part of memory to target network
        for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
    ]

    if len(memories_list) > 0:
        memories = torch.stack(memories_list).unsqueeze(0)
        next_memories = torch.stack(next_memories_list).unsqueeze(0)
    else:
        memories = None
        next_memories = None
    # Q network memories are 0'ed out, since we don't have them during inference.
    q_memories = (
        torch.zeros_like(next_memories) if next_memories is not None else None
    )

    vis_obs: List[torch.Tensor] = []
    next_vis_obs: List[torch.Tensor] = []
    if self.policy.use_vis_obs:
        vis_obs = []
        for idx, _ in enumerate(
            self.policy.actor_critic.network_body.visual_processors
        ):
            vis_ob = ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
            vis_obs.append(vis_ob)
            next_vis_ob = ModelUtils.list_to_tensor(
                batch["next_visual_obs%d" % idx]
            )
            next_vis_obs.append(next_vis_ob)

    # Copy normalizers from policy
    self.value_network.q1_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    self.value_network.q2_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    self.target_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body
    )
    # For discrete actions, the per-action log probs are requested (all_log_probs=True).
    (sampled_actions, _, log_probs, _, _) = self.policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        memories=memories,
        seq_len=self.policy.sequence_length,
        all_log_probs=not self.policy.use_continuous_act,
    )
    value_estimates, _ = self.policy.actor_critic.critic_pass(
        vec_obs, vis_obs, memories, sequence_length=self.policy.sequence_length
    )
    if self.policy.use_continuous_act:
        squeezed_actions = actions.squeeze(-1)
        # Only need grad for q1, as that is used for policy.
        q1p_out, q2p_out = self.value_network(
            vec_obs,
            vis_obs,
            sampled_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q2_grad=False,
        )
        q1_out, q2_out = self.value_network(
            vec_obs,
            vis_obs,
            squeezed_actions,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        q1_stream, q2_stream = q1_out, q2_out
    else:
        # For discrete, you don't need to backprop through the Q for the policy
        q1p_out, q2p_out = self.value_network(
            vec_obs,
            vis_obs,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
            q1_grad=False,
            q2_grad=False,
        )
        q1_out, q2_out = self.value_network(
            vec_obs,
            vis_obs,
            memories=q_memories,
            sequence_length=self.policy.sequence_length,
        )
        # Pick out the Q values for the taken actions.
        q1_stream = self._condense_q_streams(q1_out, actions)
        q2_stream = self._condense_q_streams(q2_out, actions)

    with torch.no_grad():
        target_values, _ = self.target_network(
            next_vec_obs,
            next_vis_obs,
            memories=next_memories,
            sequence_length=self.policy.sequence_length,
        )
    masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
    use_discrete = not self.policy.use_continuous_act
    dones = ModelUtils.list_to_tensor(batch["done"])

    q1_loss, q2_loss = self.sac_q_loss(
        q1_stream, q2_stream, target_values, dones, rewards, masks
    )
    value_loss = self.sac_value_loss(
        log_probs, value_estimates, q1p_out, q2p_out, masks, use_discrete
    )
    policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks, use_discrete)
    entropy_loss = self.sac_entropy_loss(log_probs, masks, use_discrete)

    total_value_loss = q1_loss + q2_loss + value_loss

    # The same decayed learning rate is applied to all three optimizers.
    decay_lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
    ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
    self.value_optimizer.zero_grad()
    total_value_loss.backward()
    self.value_optimizer.step()

    ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
    self.entropy_optimizer.zero_grad()
    entropy_loss.backward()
    self.entropy_optimizer.step()

    # Update target network
    ModelUtils.soft_update(
        self.policy.actor_critic.critic, self.target_network, self.tau
    )
    update_stats = {
        "Losses/Policy Loss": policy_loss.item(),
        "Losses/Value Loss": value_loss.item(),
        "Losses/Q1 Loss": q1_loss.item(),
        "Losses/Q2 Loss": q2_loss.item(),
        "Policy/Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef)).item(),
        "Policy/Learning Rate": decay_lr,
    }

    return update_stats
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Updates model using buffer.

    Runs one SAC update step on ``batch``, in this order:
    Q losses, value loss, policy loss and entropy-coefficient loss are computed,
    then the policy, value and entropy optimizers are stepped, and the target
    network is soft-updated from ``self._critic``.

    :param batch: Experience mini-batch.
    :param num_sequences: Number of trajectories in batch (not referenced in
        this method body).
    :return: Output from update process: a dict of scalar loss / coefficient /
        learning-rate statistics keyed by their display names.
    """
    rewards = {}
    for name in self.reward_signals:
        rewards[name] = ModelUtils.list_to_tensor(
            batch[RewardSignalUtil.rewards_key(name)])

    n_obs = len(self.policy.behavior_spec.observation_specs)
    current_obs = ObsUtil.from_buffer(batch, n_obs)
    # Convert to tensors
    current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

    next_obs = ObsUtil.from_buffer_next(batch, n_obs)
    # Convert to tensors
    next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

    act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
    actions = AgentAction.from_buffer(batch)

    # One stored policy memory per sequence: take every sequence_length-th entry.
    memories_list = [
        ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
        for i in range(0, len(batch[BufferKey.MEMORY]),
                       self.policy.sequence_length)
    ]
    # Same sampling for the critic memories: one per sequence start.
    value_memories_list = [
        ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
        for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                       self.policy.sequence_length)
    ]

    if len(memories_list) > 0:
        memories = torch.stack(memories_list).unsqueeze(0)
        value_memories = torch.stack(value_memories_list).unsqueeze(0)
    else:
        # No recorded memories (non-recurrent policy).
        memories = None
        value_memories = None

    # Q network memories are 0'ed out, since we don't have them during inference.
    # (The V network / critic below still uses the recorded value_memories.)
    q_memories = (torch.zeros_like(value_memories)
                  if value_memories is not None else None)

    # Copy normalizers from policy
    self.q_network.q1_network.network_body.copy_normalization(
        self.policy.actor.network_body)
    self.q_network.q2_network.network_body.copy_normalization(
        self.policy.actor.network_body)
    self.target_network.network_body.copy_normalization(
        self.policy.actor.network_body)
    self._critic.network_body.copy_normalization(
        self.policy.actor.network_body)
    sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats(
        current_obs,
        masks=act_masks,
        memories=memories,
        sequence_length=self.policy.sequence_length,
    )
    value_estimates, _ = self._critic.critic_pass(
        current_obs, value_memories, sequence_length=self.policy.sequence_length)

    cont_sampled_actions = sampled_actions.continuous_tensor
    cont_actions = actions.continuous_tensor
    # Q values of the actions freshly sampled from the current policy
    # (used by the value and policy losses); Q2 is evaluated with q2_grad=False.
    q1p_out, q2p_out = self.q_network(
        current_obs,
        cont_sampled_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
        q2_grad=False,
    )
    # Q values of the actions actually taken in the buffer (used by the Q losses).
    q1_out, q2_out = self.q_network(
        current_obs,
        cont_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )

    if self._action_spec.discrete_size > 0:
        disc_actions = actions.discrete_tensor
        q1_stream = self._condense_q_streams(q1_out, disc_actions)
        q2_stream = self._condense_q_streams(q2_out, disc_actions)
    else:
        q1_stream, q2_stream = q1_out, q2_out

    with torch.no_grad():
        # Since we didn't record the next value memories, evaluate one step in the critic to
        # get them.
        if value_memories is not None:
            # Get the first observation in each sequence
            just_first_obs = [
                _obs[::self.policy.sequence_length] for _obs in current_obs
            ]
            _, next_value_memories = self._critic.critic_pass(
                just_first_obs, value_memories, sequence_length=1)
        else:
            next_value_memories = None
        target_values, _ = self.target_network(
            next_obs,
            memories=next_value_memories,
            sequence_length=self.policy.sequence_length,
        )
    masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
    dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE])

    q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                       dones, rewards, masks)
    value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                     q2p_out, masks)
    policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
    entropy_loss = self.sac_entropy_loss(log_probs, masks)

    total_value_loss = q1_loss + q2_loss
    # With a shared critic the value loss is optimized together with the policy
    # (policy_optimizer); otherwise it goes to the value optimizer.
    if self.policy.shared_critic:
        policy_loss += value_loss
    else:
        total_value_loss += value_loss

    decay_lr = self.decay_learning_rate.get_value(
        self.policy.get_current_step())
    ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
    self.value_optimizer.zero_grad()
    total_value_loss.backward()
    self.value_optimizer.step()

    ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
    self.entropy_optimizer.zero_grad()
    entropy_loss.backward()
    self.entropy_optimizer.step()

    # Update target network (soft update from self._critic with coefficient self.tau)
    ModelUtils.soft_update(self._critic, self.target_network, self.tau)
    update_stats = {
        "Losses/Policy Loss": policy_loss.item(),
        "Losses/Value Loss": value_loss.item(),
        "Losses/Q1 Loss": q1_loss.item(),
        "Losses/Q2 Loss": q2_loss.item(),
        "Policy/Discrete Entropy Coeff":
        torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
        "Policy/Continuous Entropy Coeff":
        torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
        "Policy/Learning Rate": decay_lr,
    }

    return update_stats
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
    """
    Score every experience in ``mini_batch`` by the disagreement between
    ``self._training_network`` and ``self._random_network``.

    :param mini_batch: Experiences to score; passed unchanged to both networks.
    :return: NumPy array with, for each row, the squared difference
        ``(prediction - target) ** 2`` summed over dim 1.
    """
    with torch.no_grad():
        # Keep the original call order: random network first, then training network.
        target = self._random_network(mini_batch)
        prediction = self._training_network(mini_batch)
        squared_error = (prediction - target).pow(2)
        rewards = squared_error.sum(dim=1)
    return rewards.detach().cpu().numpy()
def get_trajectory_and_baseline_value_estimates(
    self,
    batch: AgentBuffer,
    next_obs: List[np.ndarray],
    next_groupmate_obs: List[List[np.ndarray]],
    done: bool,
    agent_id: str = "",
) -> Tuple[
    Dict[str, np.ndarray],
    Dict[str, np.ndarray],
    Dict[str, float],
    Optional[AgentBufferField],
    Optional[AgentBufferField],
]:
    """
    Get value estimates, baseline estimates, and memories for a trajectory, in batch form.

    All critic evaluations (including the bootstrap pass on the next
    observation) run under ``torch.no_grad()``; every returned estimate is
    converted to numpy.

    :param batch: An AgentBuffer that consists of a trajectory.
    :param next_obs: the next observation (after the trajectory). Used for bootstrapping
        if this is not a terminal trajectory.
    :param next_groupmate_obs: the next observations from other members of the group.
    :param done: Set true if this is a terminal trajectory. When set, the last
        bootstrap value of every reward signal whose ``ignore_done`` is False is zeroed.
    :param agent_id: Agent ID of the agent that this trajectory belongs to.
    :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
        the baseline estimates as a Dict, the final value estimate as a Dict of [name, float],
        and optionally (if using memories) an AgentBufferField of initial critic
        and baseline memories to be used during update.
    """
    n_obs = len(self.policy.behavior_spec.observation_specs)

    current_obs = ObsUtil.from_buffer(batch, n_obs)
    groupmate_obs = GroupObsUtil.from_buffer(batch, n_obs)

    current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]
    groupmate_obs = [
        [ModelUtils.list_to_tensor(obs) for obs in _groupmate_obs]
        for _groupmate_obs in groupmate_obs
    ]

    groupmate_actions = AgentAction.group_from_buffer(batch)

    # Single bootstrap observation: add a leading batch dimension of 1.
    next_obs = [ModelUtils.list_to_tensor(obs).unsqueeze(0) for obs in next_obs]

    next_groupmate_obs = [
        ModelUtils.list_to_tensor_list(_list_obs)
        for _list_obs in next_groupmate_obs
    ]
    # Expand dimensions of next critic obs
    next_groupmate_obs = [
        [_obs.unsqueeze(0) for _obs in _list_obs]
        for _list_obs in next_groupmate_obs
    ]

    if agent_id in self.value_memory_dict:
        # The agent_id should always be in both since they are added together
        _init_value_mem = self.value_memory_dict[agent_id]
        _init_baseline_mem = self.baseline_memory_dict[agent_id]
    else:
        _init_value_mem = (
            torch.zeros((1, 1, self.critic.memory_size))
            if self.policy.use_recurrent
            else None
        )
        _init_baseline_mem = (
            torch.zeros((1, 1, self.critic.memory_size))
            if self.policy.use_recurrent
            else None
        )

    # groupmate_obs / next_groupmate_obs are always lists (built by the
    # comprehensions above), so the agent's own observations simply go first.
    all_obs = [current_obs] + groupmate_obs
    all_next_obs = [next_obs] + next_groupmate_obs

    all_next_value_mem: Optional[AgentBufferField] = None
    all_next_baseline_mem: Optional[AgentBufferField] = None
    # Nothing here is trained, so skip autograd bookkeeping for every critic
    # pass, including the bootstrap pass on the next observation.
    with torch.no_grad():
        if self.policy.use_recurrent:
            (
                value_estimates,
                baseline_estimates,
                all_next_value_mem,
                all_next_baseline_mem,
                next_value_mem,
                next_baseline_mem,
            ) = self._evaluate_by_sequence_team(
                current_obs,
                groupmate_obs,
                groupmate_actions,
                _init_value_mem,
                _init_baseline_mem,
            )
        else:
            value_estimates, next_value_mem = self.critic.critic_pass(
                all_obs, _init_value_mem, sequence_length=batch.num_experiences
            )
            groupmate_obs_and_actions = (groupmate_obs, groupmate_actions)
            baseline_estimates, next_baseline_mem = self.critic.baseline(
                current_obs,
                groupmate_obs_and_actions,
                _init_baseline_mem,
                sequence_length=batch.num_experiences,
            )
        # Store the memory for the next trajectory
        self.value_memory_dict[agent_id] = next_value_mem
        self.baseline_memory_dict[agent_id] = next_baseline_mem

        next_value_estimates, _ = self.critic.critic_pass(
            all_next_obs, next_value_mem, sequence_length=1
        )

    for name, estimate in baseline_estimates.items():
        baseline_estimates[name] = ModelUtils.to_numpy(estimate)

    for name, estimate in value_estimates.items():
        value_estimates[name] = ModelUtils.to_numpy(estimate)

    # The baseline and V should not be on the same done flag: only the
    # bootstrap value estimates are zeroed below.
    for name, estimate in next_value_estimates.items():
        next_value_estimates[name] = ModelUtils.to_numpy(estimate)

    if done:
        for k in next_value_estimates:
            if not self.reward_signals[k].ignore_done:
                next_value_estimates[k][-1] = 0.0

    return (
        value_estimates,
        baseline_estimates,
        next_value_estimates,
        all_next_value_mem,
        all_next_baseline_mem,
    )
def get_trajectory_value_estimates(
    self,
    batch: AgentBuffer,
    next_obs: List[np.ndarray],
    done: bool,
    agent_id: str = "",
) -> Tuple[Dict[str, np.ndarray], Dict[str, float], Optional[AgentBufferField]]:
    """
    Get value estimates and memories for a trajectory, in batch form.

    All critic evaluations (including the bootstrap pass on the next
    observation) run under ``torch.no_grad()``; the estimates are returned as
    numpy arrays.

    :param batch: An AgentBuffer that consists of a trajectory.
    :param next_obs: the next observation (after the trajectory). Used for bootstrapping
        if this is not a terminal trajectory.
    :param done: Set true if this is a terminal trajectory. When set, the bootstrap
        value of every reward signal whose ``ignore_done`` is False is replaced by 0.0,
        and the stored critic memory for ``agent_id`` is discarded.
    :param agent_id: Agent ID of the agent that this trajectory belongs to.
    :returns: A Tuple of the Value Estimates as a Dict of [name, np.ndarray(trajectory_len)],
        the final value estimate as a Dict of [name, float], and optionally (if using memories)
        an AgentBufferField of initial critic memories to be used during update.
    """
    n_obs = len(self.policy.behavior_spec.observation_specs)

    if agent_id in self.critic_memory_dict:
        memory = self.critic_memory_dict[agent_id]
    else:
        memory = (
            torch.zeros((1, 1, self.critic.memory_size))
            if self.policy.use_recurrent
            else None
        )

    # Convert to tensors
    current_obs = [
        ModelUtils.list_to_tensor(obs) for obs in ObsUtil.from_buffer(batch, n_obs)
    ]
    # Single bootstrap observation: add a leading batch dimension of 1.
    next_obs = [ModelUtils.list_to_tensor(obs).unsqueeze(0) for obs in next_obs]

    # If we're using LSTM, we want to get all the intermediate memories.
    all_next_memories: Optional[AgentBufferField] = None

    # To prevent memory leak and improve performance, evaluate with no_grad.
    # This also covers the bootstrap pass on next_obs, whose output is only
    # converted to numpy and never back-propagated.
    with torch.no_grad():
        if self.policy.use_recurrent:
            (
                value_estimates,
                all_next_memories,
                next_memory,
            ) = self._evaluate_by_sequence(current_obs, memory)
        else:
            value_estimates, next_memory = self.critic.critic_pass(
                current_obs, memory, sequence_length=batch.num_experiences
            )

        # Store the memory for the next trajectory. This should NOT have a gradient.
        self.critic_memory_dict[agent_id] = next_memory

        next_value_estimate, _ = self.critic.critic_pass(
            next_obs, next_memory, sequence_length=1
        )

    for name, estimate in value_estimates.items():
        value_estimates[name] = ModelUtils.to_numpy(estimate)
    for name, estimate in next_value_estimate.items():
        next_value_estimate[name] = ModelUtils.to_numpy(estimate)

    if done:
        for k in next_value_estimate:
            if not self.reward_signals[k].ignore_done:
                next_value_estimate[k] = 0.0
        # The trajectory is terminal, so the memory stored above must not
        # carry over into this agent's next trajectory.
        self.critic_memory_dict.pop(agent_id, None)
    return value_estimates, next_value_estimate, all_next_memories
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Updates model using buffer.

    Runs one SAC update step on ``batch``: Q, value, policy and entropy
    losses are computed, the policy, value and entropy optimizers are stepped
    in that order, and the target network is soft-updated from
    ``self.policy.actor_critic.critic``.

    :param batch: Experience mini-batch.
    :param num_sequences: Number of trajectories in batch (not referenced in
        this method body).
    :return: Output from update process: a dict of scalar loss / coefficient /
        learning-rate statistics keyed by their display names.
    """
    rewards = {}
    for name in self.reward_signals:
        rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

    n_obs = len(self.policy.behavior_spec.sensor_specs)
    current_obs = ObsUtil.from_buffer(batch, n_obs)
    # Convert to tensors
    current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

    next_obs = ObsUtil.from_buffer_next(batch, n_obs)
    # Convert to tensors
    next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

    act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
    actions = AgentAction.from_dict(batch)

    # One stored memory per sequence: take every sequence_length-th entry.
    memories_list = [
        ModelUtils.list_to_tensor(batch["memory"][i])
        for i in range(0, len(batch["memory"]), self.policy.sequence_length)
    ]
    # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
    # With sequence_length > 1 the "next" memory is taken from index 1 of each
    # sequence (one step after its start); otherwise index 0 is used.
    offset = 1 if self.policy.sequence_length > 1 else 0
    next_memories_list = [
        ModelUtils.list_to_tensor(
            batch["memory"][i][self.policy.m_size // 2:]
        )  # only pass value part of memory to target network (second half of each memory vector)
        for i in range(offset, len(batch["memory"]), self.policy.sequence_length)
    ]

    if len(memories_list) > 0:
        memories = torch.stack(memories_list).unsqueeze(0)
        next_memories = torch.stack(next_memories_list).unsqueeze(0)
    else:
        # No recorded memories (non-recurrent policy).
        memories = None
        next_memories = None
    # Q network memories are 0'ed out, since we don't have them during inference.
    q_memories = (torch.zeros_like(next_memories)
                  if next_memories is not None else None)

    # Copy normalizers from policy
    self.value_network.q1_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body)
    self.value_network.q2_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body)
    self.target_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body)
    (
        sampled_actions,
        log_probs,
        _,
        value_estimates,
        _,
    ) = self.policy.actor_critic.get_action_stats_and_value(
        current_obs,
        masks=act_masks,
        memories=memories,
        sequence_length=self.policy.sequence_length,
    )
    cont_sampled_actions = sampled_actions.continuous_tensor
    cont_actions = actions.continuous_tensor
    # Q values of actions freshly sampled from the current policy (used by the
    # value and policy losses); Q2 is evaluated with q2_grad=False.
    q1p_out, q2p_out = self.value_network(
        current_obs,
        cont_sampled_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
        q2_grad=False,
    )
    # Q values of the actions actually taken in the buffer (used by the Q losses).
    q1_out, q2_out = self.value_network(
        current_obs,
        cont_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )

    if self._action_spec.discrete_size > 0:
        disc_actions = actions.discrete_tensor
        q1_stream = self._condense_q_streams(q1_out, disc_actions)
        q2_stream = self._condense_q_streams(q2_out, disc_actions)
    else:
        q1_stream, q2_stream = q1_out, q2_out

    with torch.no_grad():
        target_values, _ = self.target_network(
            next_obs,
            memories=next_memories,
            sequence_length=self.policy.sequence_length,
        )
    masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
    dones = ModelUtils.list_to_tensor(batch["done"])

    q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                       dones, rewards, masks)
    value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                     q2p_out, masks)
    policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
    entropy_loss = self.sac_entropy_loss(log_probs, masks)

    # Q losses and the state-value loss share the value optimizer in this version.
    total_value_loss = q1_loss + q2_loss + value_loss

    decay_lr = self.decay_learning_rate.get_value(
        self.policy.get_current_step())
    ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
    self.value_optimizer.zero_grad()
    total_value_loss.backward()
    self.value_optimizer.step()

    ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
    self.entropy_optimizer.zero_grad()
    entropy_loss.backward()
    self.entropy_optimizer.step()

    # Update target network (soft update from the actor-critic's critic with coefficient self.tau)
    ModelUtils.soft_update(self.policy.actor_critic.critic,
                           self.target_network, self.tau)
    update_stats = {
        "Losses/Policy Loss": policy_loss.item(),
        "Losses/Value Loss": value_loss.item(),
        "Losses/Q1 Loss": q1_loss.item(),
        "Losses/Q2 Loss": q2_loss.item(),
        "Policy/Discrete Entropy Coeff":
        torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
        "Policy/Continuous Entropy Coeff":
        torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
        "Policy/Learning Rate": decay_lr,
    }

    return update_stats
def sac_value_loss(
    self,
    log_probs: ActionLogProbs,
    values: Dict[str, torch.Tensor],
    q1p_out: Dict[str, torch.Tensor],
    q2p_out: Dict[str, torch.Tensor],
    loss_masks: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the SAC state-value loss, averaged over reward streams.

    For every reward stream the regression target is the soft value
    ``min(Q1, Q2) - alpha * log_prob``, built from the Q estimates of the
    current policy's actions (``q1p_out`` / ``q2p_out``):

    * continuous-only: ``min(Q1, Q2) - sum(alpha_c * log_prob_c)``;
    * discrete (optionally mixed): branch-averaged expected Q minus the
      branch-averaged ``sum(alpha_i * p * log p)``, and, if continuous actions
      exist, minus ``sum(alpha_c * log_prob_c)`` as well.

    Targets are computed without gradient tracking; only ``values`` receives
    gradients.

    :param log_probs: Log-probabilities of the policy's sampled actions.
    :param values: Value estimates per reward stream.
    :param q1p_out: Q1 outputs for the policy's actions, per reward stream.
    :param q2p_out: Q2 outputs for the policy's actions, per reward stream.
    :param loss_masks: Per-sample mask passed to ``ModelUtils.masked_mean``.
    :return: Scalar value loss (mean over reward streams).
    :raises UnityTrainerException: If the loss contains Inf or NaN.
    """
    min_policy_qs: Dict[str, torch.Tensor] = {}
    with torch.no_grad():
        _cont_ent_coef = self._log_ent_coef.continuous.exp()
        _disc_ent_coef = self._log_ent_coef.discrete.exp()
        if self._action_spec.discrete_size <= 0:
            for name in values.keys():
                min_policy_qs[name] = torch.min(q1p_out[name], q2p_out[name])
        else:
            # Identical for every reward stream, so compute it once.
            disc_action_probs = log_probs.all_discrete_tensor.exp()

            def _branch_mean_expected_q(q_out: torch.Tensor) -> torch.Tensor:
                # Probability-weighted Q summed within each branch, then
                # averaged across branches.
                branched = ModelUtils.break_into_branches(
                    q_out * disc_action_probs,
                    self._action_spec.discrete_branches,
                )
                return torch.mean(
                    torch.stack(
                        [torch.sum(_br, dim=1, keepdim=True) for _br in branched]
                    ),
                    dim=0,
                )

            for name in values.keys():
                min_policy_qs[name] = torch.min(
                    _branch_mean_expected_q(q1p_out[name]),
                    _branch_mean_expected_q(q2p_out[name]),
                )

    value_losses = []
    if self._action_spec.discrete_size <= 0:
        for name in values.keys():
            with torch.no_grad():
                v_backup = min_policy_qs[name] - torch.sum(
                    _cont_ent_coef * log_probs.continuous_tensor, dim=1
                )
            # reduction="none" keeps per-sample errors so loss_masks can weight
            # them; the default "mean" collapses to a scalar before masking.
            value_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(
                    values[name], v_backup, reduction="none"
                ),
                loss_masks,
            )
            value_losses.append(value_loss)
    else:
        with torch.no_grad():
            disc_log_probs = log_probs.all_discrete_tensor
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # We have to do entropy bonus per action branch
            branched_ent_bonus = torch.stack(
                [
                    torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                    for i, _lp in enumerate(branched_per_action_ent)
                ]
            )
            mean_disc_ent_term = torch.mean(branched_ent_bonus, dim=0)
            cont_ent_term = None
            if self._action_spec.continuous_size > 0:
                cont_ent_term = torch.sum(
                    _cont_ent_coef * log_probs.continuous_tensor,
                    dim=1,
                    keepdim=True,
                )
        for name in values.keys():
            with torch.no_grad():
                v_backup = min_policy_qs[name] - mean_disc_ent_term
                if cont_ent_term is not None:
                    # Soft value subtracts alpha * log_prob, exactly as in the
                    # continuous-only branch above (previously this was added).
                    v_backup = v_backup - cont_ent_term
            value_loss = 0.5 * ModelUtils.masked_mean(
                torch.nn.functional.mse_loss(
                    values[name], v_backup.squeeze(-1), reduction="none"
                ),
                loss_masks,
            )
            value_losses.append(value_loss)
    value_loss = torch.mean(torch.stack(value_losses))
    if torch.isinf(value_loss).any() or torch.isnan(value_loss).any():
        raise UnityTrainerException("Inf found")
    return value_loss
def evaluate(self, mini_batch: AgentBuffer) -> np.ndarray:
    """
    Compute per-experience rewards for ``mini_batch``.

    The raw values come from ``self._network.compute_reward`` (evaluated
    without gradient tracking). They are clipped from above at
    ``1.0 / self.strength`` and finally multiplied by
    ``self._has_updated_once`` (presumably a 0/1 flag that silences rewards
    before the first update -- confirm where it is set).

    :param mini_batch: Experiences to score.
    :return: NumPy array of clipped, gated rewards.
    """
    with torch.no_grad():
        raw_rewards = self._network.compute_reward(mini_batch)
        reward_array = ModelUtils.to_numpy(raw_rewards)
    upper_bound = 1.0 / self.strength
    clipped = np.minimum(reward_array, upper_bound)
    return clipped * self._has_updated_once