def sac_policy_loss(
    self,
    log_probs: torch.Tensor,
    q1p_outs: Dict[str, torch.Tensor],
    loss_masks: torch.Tensor,
    discrete: bool,
) -> torch.Tensor:
    """
    SAC policy loss: entropy-weighted log probabilities minus the Q value of
    the policy's own actions, averaged under ``loss_masks``.

    :param log_probs: Log probabilities from the policy. In the discrete case
        this is the concatenation of every branch's per-action log probs
        (it is split with ``break_into_branches`` on ``self.act_size``).
    :param q1p_outs: Q1 outputs for the policy's actions, keyed by reward
        stream; they are averaged across streams.
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :param discrete: Selects the discrete (per-branch) or continuous path.
    :return: Scalar policy loss.
    """
    # The entropy coefficient is stored in log space.
    _ent_coef = torch.exp(self._log_ent_coef)
    # Average the Q estimates over all reward streams.
    mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
    if not discrete:
        # Add a trailing axis so mean_q1 broadcasts against log_probs.
        mean_q1 = mean_q1.unsqueeze(1)
        batch_policy_loss = torch.mean(_ent_coef * log_probs - mean_q1, dim=1)
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
    else:
        action_probs = log_probs.exp()
        # p * log p per action, split by branch.
        branched_per_action_ent = ModelUtils.break_into_branches(
            log_probs * action_probs, self.act_size
        )
        # Probability-weighted Q per action, split by branch.
        branched_q_term = ModelUtils.break_into_branches(
            mean_q1 * action_probs, self.act_size
        )
        # Each branch uses its own entropy coefficient _ent_coef[i].
        branched_policy_loss = torch.stack(
            [
                torch.sum(_ent_coef[i] * _lp - _qt, dim=1, keepdim=True)
                for i, (_lp, _qt) in enumerate(
                    zip(branched_per_action_ent, branched_q_term)
                )
            ],
            dim=1,
        )
        # NOTE(review): squeeze() without a dim also removes the batch axis
        # when the batch size is 1 — confirm masked_mean tolerates that.
        batch_policy_loss = torch.squeeze(branched_policy_loss)
        policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
    return policy_loss
def sac_q_loss(
    self,
    q1_out: Dict[str, torch.Tensor],
    q2_out: Dict[str, torch.Tensor],
    target_values: Dict[str, torch.Tensor],
    dones: torch.Tensor,
    rewards: Dict[str, torch.Tensor],
    loss_masks: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Soft Q-function losses for both Q networks.

    For each reward stream the Bellman backup
    ``r + (1 - use_dones_in_backup * done) * gamma * V_target`` is built
    without gradient, and each Q network is regressed onto it with a masked,
    per-sample squared error (scaled by 0.5).

    :param q1_out: Q1 estimates for the taken actions, keyed by reward stream.
    :param q2_out: Q2 estimates for the taken actions, keyed by reward stream.
    :param target_values: Target-network value estimates, keyed by reward stream.
    :param dones: Done flags for each step.
    :param rewards: Rewards for each step, keyed by reward stream.
    :param loss_masks: Mask selecting which steps contribute to the loss.
    :return: Tuple (q1_loss, q2_loss), each averaged over reward streams.
    """
    q1_losses = []
    q2_losses = []
    # Multiple q losses per stream; self.gammas is indexed by stream position.
    for i, name in enumerate(q1_out.keys()):
        q1_stream = q1_out[name].squeeze()
        q2_stream = q2_out[name].squeeze()
        with torch.no_grad():
            q_backup = rewards[name] + (
                (1.0 - self.use_dones_in_backup[name] * dones)
                * self.gammas[i]
                * target_values[name]
            )
        # reduction="none" keeps one error per sample so masked_mean can drop
        # the masked-out steps; the default "mean" reduction collapses to a
        # scalar first, which turns the mask into a no-op.
        _q1_loss = 0.5 * ModelUtils.masked_mean(
            torch.nn.functional.mse_loss(q_backup, q1_stream, reduction="none"),
            loss_masks,
        )
        _q2_loss = 0.5 * ModelUtils.masked_mean(
            torch.nn.functional.mse_loss(q_backup, q2_stream, reduction="none"),
            loss_masks,
        )
        q1_losses.append(_q1_loss)
        q2_losses.append(_q2_loss)
    q1_loss = torch.mean(torch.stack(q1_losses))
    q2_loss = torch.mean(torch.stack(q2_losses))
    return q1_loss, q2_loss
def forward(
    self,
    obs_only: List[List[torch.Tensor]],
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode the observations of several entities with self-attention and
    return the resulting encoding (run through the LSTM when one is used).

    Entities whose raw observations are NaN-padded are masked out of the
    attention, using masks built by ``_get_masks_from_nans``.

    :param obs_only: Observations to be processed that do not have corresponding actions.
        These are encoded with the obs_encoder.
    :param obs: Observations to be processed that do have corresponding actions.
        After concatenation with actions, these are processed with obs_action_encoder.
    :param actions: After concatenation with obs, these are processed with obs_action_encoder.
    :param memories: If using memory, a Tensor of initial memories.
    :param sequence_length: If using memory, the sequence length.
    :return: Tuple of (encoding, memories). ``memories`` is returned
        unchanged when ``self.use_lstm`` is False.
    NOTE(review): if both ``obs`` and ``obs_only`` are empty,
    ``torch.cat`` receives an empty list and raises.
    """
    self_attn_masks = []
    self_attn_inputs = []
    concat_f_inp = []
    if obs:
        obs_attn_mask = self._get_masks_from_nans(obs)
        obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
        for inputs, action in zip(obs, actions):
            encoded = self.observation_encoder(inputs)
            # Concatenate each entity's encoding with its flattened action.
            cat_encodes = [
                encoded,
                action.to_flat(self.action_spec.discrete_branches),
            ]
            concat_f_inp.append(torch.cat(cat_encodes, dim=1))
        # Stack entities along dim 1.
        f_inp = torch.stack(concat_f_inp, dim=1)
        self_attn_masks.append(obs_attn_mask)
        self_attn_inputs.append(self.obs_action_encoder(None, f_inp))

    concat_encoded_obs = []
    if obs_only:
        obs_only_attn_mask = self._get_masks_from_nans(obs_only)
        obs_only = self._copy_and_remove_nans_from_obs(
            obs_only, obs_only_attn_mask)
        for inputs in obs_only:
            encoded = self.observation_encoder(inputs)
            concat_encoded_obs.append(encoded)
        g_inp = torch.stack(concat_encoded_obs, dim=1)
        self_attn_masks.append(obs_only_attn_mask)
        self_attn_inputs.append(self.obs_encoder(None, g_inp))

    # Join both entity groups along the entity axis before self-attention.
    encoded_entity = torch.cat(self_attn_inputs, dim=1)
    encoded_state = self.self_attn(encoded_entity, self_attn_masks)

    encoding = self.linear_encoder(encoded_state)
    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        encoding = encoding.reshape([-1, sequence_length, self.h_size])
        encoding, memories = self.lstm(encoding, memories)
        encoding = encoding.reshape([-1, self.m_size // 2])
    return encoding, memories
def _behavioral_cloning_loss(
    self,
    selected_actions: AgentAction,
    log_probs: ActionLogProbs,
    expert_actions: AgentAction,
) -> torch.Tensor:
    """
    Behavioral-cloning loss between the policy's actions and expert actions.

    Continuous part: mean squared error between the selected and the expert
    continuous actions. Discrete part: the expert actions are one-hot encoded
    per branch, and the cross-entropy against ``log_softmax`` of that branch's
    log probabilities is averaged over branches and samples. The two parts
    are summed.

    :param selected_actions: Actions produced by the policy.
    :param log_probs: Log probabilities produced by the policy.
    :param expert_actions: Expert demonstration actions; ``.continuous_tensor``
        and ``.discrete_tensor`` are read from it.
    :return: Loss tensor (the int 0 if the action spec has neither continuous
        nor discrete actions).
    """
    action_spec = self.policy.behavior_spec.action_spec
    bc_loss = 0
    if action_spec.continuous_size > 0:
        bc_loss += torch.nn.functional.mse_loss(
            selected_actions.continuous_tensor, expert_actions.continuous_tensor
        )
    if action_spec.discrete_size > 0:
        one_hot_expert_actions = ModelUtils.actions_to_onehot(
            expert_actions.discrete_tensor,
            action_spec.discrete_branches,
        )
        log_prob_branches = ModelUtils.break_into_branches(
            log_probs.all_discrete_tensor,
            action_spec.discrete_branches,
        )
        # Cross-entropy per branch, then averaged over branches and samples.
        bc_loss += torch.mean(
            torch.stack(
                [
                    torch.sum(
                        -torch.nn.functional.log_softmax(log_prob_branch, dim=1)
                        * expert_actions_branch,
                        dim=1,
                    )
                    for log_prob_branch, expert_actions_branch in zip(
                        log_prob_branches, one_hot_expert_actions
                    )
                ]
            )
        )
    return bc_loss
def sample_actions(
    self,
    vec_obs: List[torch.Tensor],
    vis_obs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    seq_len: int = 1,
    all_log_probs: bool = False,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
    """
    Sample actions from the actor-critic and report their statistics.

    :param vec_obs: Vector observations.
    :param vis_obs: Visual observations.
    :param masks: Masks forwarded to ``get_dist_and_value``.
    :param memories: Input memories, if any.
    :param seq_len: Sequence length used for recurrent processing.
    :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
    :return: Tuple of (actions, log probs, entropies, value heads, memories).
    """
    dists, value_heads, memories = self.actor_critic.get_dist_and_value(
        vec_obs, vis_obs, masks, memories, seq_len
    )
    sampled = self.actor_critic.sample_action(dists)
    chosen_log_probs, entropies, every_log_prob = ModelUtils.get_probs_and_entropy(
        sampled, dists
    )
    stacked = torch.stack(sampled, dim=-1)
    # Continuous keeps index 0 of the stacking axis; discrete keeps index 0
    # of axis 1 and one column per stacked distribution.
    actions = stacked[:, :, 0] if self.use_continuous_act else stacked[:, 0, :]
    reported_log_probs = every_log_prob if all_log_probs else chosen_log_probs
    return actions, reported_log_probs, entropies, value_heads, memories
def forward(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, int, int, int, int]:
    """
    Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.

    Continuous actions: returns sampled actions, optionally clipped to [-1, 1].
    Discrete actions: returns every branch's ``all_log_prob()`` output,
    concatenated along dim 1.
    NOTE(review): the return annotation lists ints, but the memory-size entry
    is built as a tensor; the other entries' types are not visible here.
    """
    # Single-step inference: sequence length is fixed at 1 for export.
    dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
    if self.action_spec.is_continuous():
        action_list = self.sample_action(dists)
        action_out = torch.stack(action_list, dim=-1)
        if self._clip_action_on_export:
            # Clamp to [-3, 3] and rescale to [-1, 1].
            action_out = torch.clamp(action_out, -3, 3) / 3
    else:
        action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
    return (
        action_out,
        self.version_number,
        torch.Tensor([self.network_body.memory_size]),
        self.is_continuous_int,
        self.act_size_vector,
    )
def sac_entropy_loss(
    self, log_probs: torch.Tensor, loss_masks: torch.Tensor, discrete: bool
) -> torch.Tensor:
    """
    Loss for the (log) entropy coefficient, pushing the policy entropy
    toward ``self.target_entropy``.

    The target/current difference is computed under ``torch.no_grad()``, so
    gradient flows only into ``self._log_ent_coef``.

    :param log_probs: Policy log probabilities (per-action and concatenated
        across branches in the discrete case).
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :param discrete: Selects the per-branch discrete path.
    :return: Scalar entropy-coefficient loss.
    """
    if not discrete:
        with torch.no_grad():
            target_current_diff = torch.sum(log_probs + self.target_entropy, dim=1)
        entropy_loss = -1 * ModelUtils.masked_mean(
            self._log_ent_coef * target_current_diff, loss_masks
        )
    else:
        with torch.no_grad():
            # p * log p per action, split by branch.
            branched_per_action_ent = ModelUtils.break_into_branches(
                log_probs * log_probs.exp(), self.act_size
            )
            # Per branch: sum(p log p) plus that branch's target entropy.
            target_current_diff_branched = torch.stack(
                [
                    torch.sum(_lp, axis=1, keepdim=True) + _te
                    for _lp, _te in zip(
                        branched_per_action_ent, self.target_entropy
                    )
                ],
                axis=1,
            )
            target_current_diff = torch.squeeze(
                target_current_diff_branched, axis=2
            )
        # Outside no_grad so the loss is differentiable w.r.t. _log_ent_coef.
        entropy_loss = -1 * ModelUtils.masked_mean(
            torch.mean(self._log_ent_coef * target_current_diff, axis=1), loss_masks
        )
    return entropy_loss
def ppo_value_loss(
    self,
    values: Dict[str, torch.Tensor],
    old_values: Dict[str, torch.Tensor],
    returns: Dict[str, torch.Tensor],
    epsilon: float,
    loss_masks: torch.Tensor,
) -> torch.Tensor:
    """
    Clipped value-function loss used by PPO.

    For every reward stream, the new estimate is clipped to stay within
    ``epsilon`` of the stored estimate. The larger of the clipped and
    unclipped squared errors against the returns is averaged under
    ``loss_masks``, and the per-stream losses are averaged together.

    :param values: Value output of the current network, keyed by stream.
    :param old_values: Value estimates stored with the experiences in the buffer.
    :param returns: Computed returns, keyed by stream.
    :param epsilon: Clipping value for the value estimate.
    :param loss_masks: Mask for losses. Used with LSTM to ignore 0'ed out experiences.
    :return: Mean clipped value loss over all streams.
    """

    def _stream_loss(name: str, estimate: torch.Tensor) -> torch.Tensor:
        previous = old_values[name]
        target = returns[name]
        clipped = previous + torch.clamp(estimate - previous, -1 * epsilon, epsilon)
        unclipped_err = (target - estimate) ** 2
        clipped_err = (target - clipped) ** 2
        return ModelUtils.masked_mean(torch.max(unclipped_err, clipped_err), loss_masks)

    per_stream = [_stream_loss(name, head) for name, head in values.items()]
    return torch.mean(torch.stack(per_stream))
def test_evaluate_actions(rnn, visual, discrete):
    """evaluate_actions returns log probs, entropy and values of the expected shapes."""
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    steps = 64
    buffer = mb.simulate_rollout(steps, policy.behavior_spec, memory_size=policy.m_size)
    act_masks = ModelUtils.list_to_tensor(buffer[BufferKey.ACTION_MASK])
    agent_action = AgentAction.from_buffer(buffer)
    obs_count = len(policy.behavior_spec.observation_specs)
    tensor_obs = [
        ModelUtils.list_to_tensor(single_obs)
        for single_obs in ObsUtil.from_buffer(buffer, obs_count)
    ]
    seq_len = policy.sequence_length
    stored_memories = buffer[BufferKey.MEMORY]
    # Keep only the memory at the start of every sequence.
    memories = [
        ModelUtils.list_to_tensor(stored_memories[start])
        for start in range(0, len(stored_memories), seq_len)
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)
    log_probs, entropy, values = policy.evaluate_actions(
        tensor_obs,
        masks=act_masks,
        actions=agent_action,
        memories=memories,
        seq_len=seq_len,
    )
    action_spec = policy.behavior_spec.action_spec
    width = action_spec.discrete_size if discrete else action_spec.continuous_size
    assert log_probs.flatten().shape == (steps, width)
    assert entropy.shape == (steps,)
    for value_head in values.values():
        assert value_head.shape == (steps,)
def test_sample_actions(rnn, visual, discrete):
    """sample_actions returns log probs, entropies and memories of the expected shapes."""
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    steps = 64
    buffer = mb.simulate_rollout(steps, policy.behavior_spec, memory_size=policy.m_size)
    act_masks = ModelUtils.list_to_tensor(buffer[BufferKey.ACTION_MASK])
    obs_count = len(policy.behavior_spec.observation_specs)
    tensor_obs = [
        ModelUtils.list_to_tensor(single_obs)
        for single_obs in ObsUtil.from_buffer(buffer, obs_count)
    ]
    seq_len = policy.sequence_length
    stored_memories = buffer[BufferKey.MEMORY]
    memories = [
        ModelUtils.list_to_tensor(stored_memories[start])
        for start in range(0, len(stored_memories), seq_len)
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)
    sampled_actions, log_probs, entropies, memories = policy.sample_actions(
        tensor_obs, masks=act_masks, memories=memories, seq_len=seq_len
    )
    action_spec = policy.behavior_spec.action_spec
    if discrete:
        assert log_probs.all_discrete_tensor.shape == (
            steps,
            sum(action_spec.discrete_branches),
        )
    else:
        assert log_probs.continuous_tensor.shape == (
            steps,
            action_spec.continuous_size,
        )
    assert entropies.shape == (steps,)
    if rnn:
        assert memories.shape == (1, 1, policy.m_size)
def discrete_tensor(self) -> torch.Tensor:
    """
    Stack the per-branch discrete actions into one tensor.

    The branch tensors are stacked along a new trailing axis. When
    ``discrete_list`` is None or empty, an empty tensor is returned instead.
    """
    branches = self.discrete_list
    if branches is None or len(branches) == 0:
        return torch.empty(0)
    return torch.stack(branches, dim=-1)
def sac_policy_loss(
    self,
    log_probs: ActionLogProbs,
    q1p_outs: Dict[str, torch.Tensor],
    loss_masks: torch.Tensor,
) -> torch.Tensor:
    """
    SAC policy loss for hybrid (discrete and/or continuous) action spaces.

    Discrete part: per branch, sum over actions of
    ``alpha_i * p * log p - p * Q``, then summed over branches. Continuous
    part: mean over dimensions of ``alpha * log_prob - E_p[Q]``, where
    ``E_p[Q]`` is the discrete-probability-weighted Q when discrete actions
    exist and the plain averaged Q otherwise. The per-sample losses are
    averaged under ``loss_masks``.

    :param log_probs: Policy log probabilities (continuous and all-discrete).
    :param q1p_outs: Q1 outputs for the policy's actions, keyed by reward stream.
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :return: Scalar policy loss.
    """
    # Entropy coefficients are stored in log space; exponentiate them here.
    _cont_ent_coef, _disc_ent_coef = (
        self._log_ent_coef.continuous,
        self._log_ent_coef.discrete,
    )
    _cont_ent_coef = _cont_ent_coef.exp()
    _disc_ent_coef = _disc_ent_coef.exp()

    # Average the Q estimates over all reward streams.
    mean_q1 = torch.mean(torch.stack(list(q1p_outs.values())), axis=0)
    batch_policy_loss = 0
    if self._action_spec.discrete_size > 0:
        disc_log_probs = log_probs.all_discrete_tensor
        disc_action_probs = disc_log_probs.exp()
        branched_per_action_ent = ModelUtils.break_into_branches(
            disc_log_probs * disc_action_probs,
            self._action_spec.discrete_branches)
        branched_q_term = ModelUtils.break_into_branches(
            mean_q1 * disc_action_probs, self._action_spec.discrete_branches)
        # One loss column per branch, each with its own coefficient.
        branched_policy_loss = torch.stack(
            [
                torch.sum(
                    _disc_ent_coef[i] * _lp - _qt, dim=1, keepdim=False)
                for i, (_lp, _qt) in enumerate(
                    zip(branched_per_action_ent, branched_q_term))
            ],
            dim=1,
        )
        batch_policy_loss += torch.sum(branched_policy_loss, dim=1)
        # Probability-weighted Q, reused by the continuous part below.
        all_mean_q1 = torch.sum(disc_action_probs * mean_q1, dim=1)
    else:
        all_mean_q1 = mean_q1
    if self._action_spec.continuous_size > 0:
        cont_log_probs = log_probs.continuous_tensor
        batch_policy_loss += torch.mean(_cont_ent_coef * cont_log_probs - all_mean_q1.unsqueeze(1), dim=1)
    policy_loss = ModelUtils.masked_mean(batch_policy_loss, loss_masks)
    return policy_loss
def test_sample_actions(rnn, visual, discrete):
    """sample_actions returns tensors of the expected shapes (string-keyed buffer API)."""
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    steps = 64
    buffer = mb.simulate_rollout(steps, policy.behavior_spec, memory_size=policy.m_size)
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])
    vis_obs = [
        ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors)
    ]
    seq_len = policy.sequence_length
    stored_memories = buffer["memory"]
    memories = [
        ModelUtils.list_to_tensor(stored_memories[start])
        for start in range(0, len(stored_memories), seq_len)
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)
    outputs = policy.sample_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        memories=memories,
        seq_len=seq_len,
        all_log_probs=not policy.use_continuous_act,
    )
    sampled_actions, clipped_actions, log_probs, entropies, memories = outputs
    action_spec = policy.behavior_spec.action_spec
    if discrete:
        assert log_probs.shape == (steps, sum(action_spec.discrete_branches))
    else:
        assert log_probs.shape == (steps, action_spec.continuous_size)
        assert clipped_actions.shape == (steps, action_spec.continuous_size)
    assert entropies.shape == (steps,)
    if rnn:
        assert memories.shape == (1, 1, policy.m_size)
def get_probs_and_entropy(
    action_list: List[torch.Tensor], dists: List[DistInstance]
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
    """
    Log probability of each sampled action, entropy of each distribution,
    and (when discrete distributions are present) every action's log prob.

    :param action_list: Sampled actions, one tensor per distribution.
    :param dists: Distributions the actions were sampled from.
    :return: (log_probs, entropies, all_log_probs). The first two are stacked
        along a new last axis; that axis is squeezed away when no
        distribution is a DiscreteDistInstance, in which case the third item
        is None. Otherwise the third item is the concatenation of every
        discrete distribution's ``all_log_prob()`` along the last axis.
    """
    per_dist_log_probs = []
    per_dist_entropies = []
    discrete_all_log_probs = []
    for sampled, dist in zip(action_list, dists):
        per_dist_log_probs.append(dist.log_prob(sampled))
        per_dist_entropies.append(dist.entropy())
        if isinstance(dist, DiscreteDistInstance):
            discrete_all_log_probs.append(dist.all_log_prob())
    log_probs = torch.stack(per_dist_log_probs, dim=-1)
    entropies = torch.stack(per_dist_entropies, dim=-1)
    if discrete_all_log_probs:
        return log_probs, entropies, torch.cat(discrete_all_log_probs, dim=-1)
    return log_probs.squeeze(-1), entropies.squeeze(-1), None
def _get_masks_from_nans(self, obs_tensors: List[torch.Tensor]) -> torch.Tensor: """ Get attention masks by grabbing an arbitrary obs across all the agents Since these are raw obs, the padded values are still NaN """ only_first_obs = [_all_obs[0] for _all_obs in obs_tensors] # Just get the first element in each obs regardless of its dimension. This will speed up # searching for NaNs. only_first_obs_flat = torch.stack( [_obs.flatten(start_dim=1)[:, 0] for _obs in only_first_obs], dim=1 ) # Get the mask from NaNs attn_mask = only_first_obs_flat.isnan().float() return attn_mask
def sample_actions(
    self,
    vec_obs: List[torch.Tensor],
    vis_obs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    seq_len: int = 1,
    all_log_probs: bool = False,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Sample actions and return them with their statistics.

    :param vec_obs: List of vector observations.
    :param vis_obs: List of visual observations.
    :param masks: Loss masks for RNN, else None.
    :param memories: Input memories when using RNN, else None.
    :param seq_len: Sequence length when using RNN.
    :param all_log_probs: Returns (for discrete actions) a tensor of log probs, one for each action.
    :return: Tuple of actions, actions clipped to -1, 1, log probabilities (dependent on all_log_probs),
        summed entropies, and output memories, all as Torch Tensors.
    """
    if memories is not None:
        # With an LSTM the value heads must run as well, so that the critic
        # memories are produced too.
        dists, _, memories = self.actor_critic.get_dist_and_value(
            vec_obs, vis_obs, masks, memories, seq_len)
    else:
        dists, memories = self.actor_critic.get_dists(
            vec_obs, vis_obs, masks, memories, seq_len)

    sampled = self.actor_critic.sample_action(dists)
    chosen_log_probs, entropies, every_log_prob = ModelUtils.get_probs_and_entropy(
        sampled, dists)
    stacked = torch.stack(sampled, dim=-1)
    actions = stacked[:, :, 0] if self.use_continuous_act else stacked[:, 0, :]
    # Use the sum of entropy across actions, not the mean
    entropy_sum = torch.sum(entropies, dim=1)

    should_clip = self._clip_action and self.use_continuous_act
    # clamp to [-3, 3] then divide by 3 -> range [-1, 1]
    clipped_action = torch.clamp(actions, -3, 3) / 3 if should_clip else actions

    reported_log_probs = every_log_prob if all_log_probs else chosen_log_probs
    return actions, clipped_action, reported_log_probs, entropy_sum, memories
def _condense_q_streams(
        self, q_output: Dict[str, torch.Tensor],
        discrete_actions: torch.Tensor) -> Dict[str, torch.Tensor]:
    """
    Reduce per-action Q outputs to the Q value of the taken discrete actions.

    Each stream's output is split into branches. Multiplying a branch by the
    one-hot taken action and summing selects that action's Q value; the
    selected values are then averaged across branches.

    :param q_output: Per-action Q outputs, keyed by reward stream.
    :param discrete_actions: Taken discrete actions, one-hot encoded against
        ``self.act_size``.
    :return: Dict with the same keys mapping to the condensed Q values.
    """
    onehot_actions = ModelUtils.actions_to_onehot(discrete_actions, self.act_size)

    def _selected_q(stream_q: torch.Tensor) -> torch.Tensor:
        branches = ModelUtils.break_into_branches(stream_q, self.act_size)
        per_branch = [
            torch.sum(onehot * branch_q, dim=1, keepdim=True)
            for onehot, branch_q in zip(onehot_actions, branches)
        ]
        return torch.mean(torch.stack(per_branch), dim=0)

    return {key: _selected_q(item) for key, item in q_output.items()}
def test_simple_transformer_training():
    """
    Train entity embeddings + self-attention + a linear head to predict the
    key closest (in Euclidean distance) to a query point, checking that the
    error decreases below its initial value and ends below 0.3.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbeddings(size, [size], [n_k], embedding_size)
    transformer = ResidualSelfAttention(embedding_size, [n_k])
    l_layer = linear_layer(embedding_size, size)
    # Train every module in the pipeline, including the entity embeddings;
    # leaving them out freezes them at their random initialization.
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters())
        + list(transformer.parameters())
        + list(l_layer.parameters()),
        lr=0.001,
    )
    batch_size = 200
    point_range = 3
    init_error = -1.0
    for _ in range(250):
        center = torch.rand((batch_size, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, [key])
        masks = EntityEmbeddings.get_masks([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.3
def test_evaluate_actions(rnn, visual, discrete):
    """evaluate_actions returns tensors of the expected shapes (string-keyed buffer API)."""
    policy = create_policy_mock(
        TrainerSettings(), use_rnn=rnn, use_discrete=discrete, use_visual=visual
    )
    steps = 64
    buffer = mb.simulate_rollout(steps, policy.behavior_spec, memory_size=policy.m_size)
    vec_obs = [ModelUtils.list_to_tensor(buffer["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(buffer["action_mask"])
    if not policy.use_continuous_act:
        actions = ModelUtils.list_to_tensor(buffer["actions"], dtype=torch.long)
    else:
        actions = ModelUtils.list_to_tensor(buffer["actions"]).unsqueeze(-1)
    vis_obs = [
        ModelUtils.list_to_tensor(buffer["visual_obs%d" % idx])
        for idx, _ in enumerate(policy.actor_critic.network_body.visual_processors)
    ]
    seq_len = policy.sequence_length
    stored_memories = buffer["memory"]
    memories = [
        ModelUtils.list_to_tensor(stored_memories[start])
        for start in range(0, len(stored_memories), seq_len)
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)
    log_probs, entropy, values = policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=seq_len,
    )
    action_spec = policy.behavior_spec.action_spec
    width = action_spec.discrete_size if discrete else action_spec.continuous_size
    assert log_probs.shape == (steps, width)
    assert entropy.shape == (steps,)
    for value_head in values.values():
        assert value_head.shape == (steps,)
def _behavioral_cloning_loss(self, selected_actions, log_probs, expert_actions):
    """
    Behavioral-cloning loss against expert actions.

    Continuous policies: mean squared error between the selected and expert
    actions. Discrete policies: ``log_probs`` is split per branch, and
    ``-log_softmax(branch) * expert_branch`` is summed over actions and then
    averaged over branches and samples (``expert_actions`` is iterated per
    branch and is presumably one-hot encoded — confirm with callers).
    """
    if not self.policy.use_continuous_act:
        branch_log_probs = ModelUtils.break_into_branches(
            log_probs, self.policy.act_size)
        per_branch_ce = [
            torch.sum(
                -torch.nn.functional.log_softmax(branch, dim=1) * expert_branch,
                dim=1,
            )
            for branch, expert_branch in zip(branch_log_probs, expert_actions)
        ]
        return torch.mean(torch.stack(per_branch_ce))
    return torch.nn.functional.mse_loss(selected_actions, expert_actions)
def test_predict_closest_training():
    """
    Train entity embedding + self-attention + a linear head to output the key
    closest (Euclidean) to a query point; the final error must be < 0.02.
    The statement order is deliberate: the seeded RNG makes the result depend
    on the exact sequence of torch.rand calls.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 200
    for _ in range(200):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        masks = get_zero_entities_mask([key])
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        # Halved mean squared error over the batch.
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.02
def sac_entropy_loss(
    self, log_probs: ActionLogProbs, loss_masks: torch.Tensor
) -> torch.Tensor:
    """
    Loss for the log entropy coefficients of a hybrid action space.

    Each target/current entropy difference is computed under
    ``torch.no_grad()``, so gradient flows only into
    ``self._log_ent_coef.discrete`` / ``.continuous``. The discrete and
    continuous contributions are summed.

    :param log_probs: Policy log probabilities (continuous and all-discrete).
    :param loss_masks: Mask forwarded to ``ModelUtils.masked_mean``.
    :return: Entropy-coefficient loss (the int 0 if the action spec is empty).
    """
    # These are the LOG coefficients (not exponentiated) — the loss is taken
    # on log alpha.
    _cont_ent_coef, _disc_ent_coef = (
        self._log_ent_coef.continuous,
        self._log_ent_coef.discrete,
    )
    entropy_loss = 0
    if self._action_spec.discrete_size > 0:
        with torch.no_grad():
            # Split the discrete log probs into one tensor per branch
            disc_log_probs = log_probs.all_discrete_tensor
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_log_probs.exp(),
                self._action_spec.discrete_branches,
            )
            # Per branch: sum(p log p) plus that branch's target entropy.
            target_current_diff_branched = torch.stack(
                [
                    torch.sum(_lp, axis=1, keepdim=True) + _te
                    for _lp, _te in zip(
                        branched_per_action_ent, self.target_entropy.discrete
                    )
                ],
                axis=1,
            )
            target_current_diff = torch.squeeze(
                target_current_diff_branched, axis=2
            )
        entropy_loss += -1 * ModelUtils.masked_mean(
            torch.mean(_disc_ent_coef * target_current_diff, axis=1), loss_masks
        )
    if self._action_spec.continuous_size > 0:
        with torch.no_grad():
            cont_log_probs = log_probs.continuous_tensor
            target_current_diff = torch.sum(
                cont_log_probs + self.target_entropy.continuous, dim=1
            )
        # We update all the _cont_ent_coef as one block
        entropy_loss += -1 * ModelUtils.masked_mean(
            _cont_ent_coef * target_current_diff, loss_masks
        )
    return entropy_loss
def test_all_masking(mask_value):
    """
    Run a few training steps with a mask that is uniformly ``mask_value``
    (e.g. all zeros or all ones) and check nothing raises. There is no
    accuracy assertion; reaching the end of the loop is the test.
    """
    # We make sure that a mask of all zeros or all ones will not trigger an error
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_k, = 3, 5
    embedding_size = 64
    entity_embeddings = EntityEmbedding(size, n_k, embedding_size)
    entity_embeddings.add_self_embedding(size)
    transformer = ResidualSelfAttention(embedding_size, n_k)
    l_layer = linear_layer(embedding_size, size)
    optimizer = torch.optim.Adam(
        list(entity_embeddings.parameters()) + list(transformer.parameters()) +
        list(l_layer.parameters()),
        lr=0.001,
        weight_decay=1e-6,
    )
    batch_size = 20
    for _ in range(5):
        center = torch.rand((batch_size, size))
        key = torch.rand((batch_size, n_k, size))
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum((center.reshape(
                (batch_size, 1, size)) - key)**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        embeddings = entity_embeddings(center, key)
        # Same mask value for every (sample, entity) pair.
        masks = [torch.ones_like(key[:, :, 0]) * mask_value]
        prediction = transformer.forward(embeddings, masks)
        prediction = l_layer(prediction)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
def test_multi_head_attention_training():
    """
    Train a MultiHeadAttention module to return the key closest (Euclidean)
    to a single query. Every step after the first must beat the initial
    error, and the final error must be < 0.5. The statement order matters
    because results depend on the seeded RNG.
    """
    np.random.seed(1336)
    torch.manual_seed(1336)
    size, n_h, n_k, n_q = 3, 10, 5, 1
    embedding_size = 64
    mha = MultiHeadAttention(size, size, size, size, n_h, embedding_size)
    optimizer = torch.optim.Adam(mha.parameters(), lr=0.001)
    batch_size = 200
    point_range = 3
    # Sentinel: the first iteration records the initial error.
    init_error = -1.0
    for _ in range(50):
        query = torch.rand(
            (batch_size, n_q, size)) * point_range * 2 - point_range
        key = torch.rand(
            (batch_size, n_k, size)) * point_range * 2 - point_range
        value = key
        with torch.no_grad():
            # create the target : The key closest to the query in euclidean distance
            distance = torch.sum((query - key)**2, dim=2)
            argmin = torch.argmin(distance, dim=1)
            target = []
            for i in range(batch_size):
                target += [key[i, argmin[i], :]]
            target = torch.stack(target, dim=0)
            target = target.detach()

        prediction, _ = mha.forward(query, key, value)
        prediction = prediction.reshape((batch_size, size))
        error = torch.mean((prediction - target)**2, dim=1)
        error = torch.mean(error) / 2
        if init_error == -1.0:
            init_error = error.item()
        else:
            assert error.item() < init_error
        print(error.item())
        optimizer.zero_grad()
        error.backward()
        optimizer.step()
    assert error.item() < 0.5
def forward(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, int, int, int, int]:
    """
    Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.

    Continuous actions: returns the sampled actions stacked along the last
    axis. Discrete actions: returns the ``all_log_prob()`` outputs of every
    branch distribution, concatenated along dim 1 in branch order.
    """
    # Single-step inference: sequence length is fixed at 1 for export.
    dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
    if self.act_type == ActionType.CONTINUOUS:
        action_list = self.sample_action(dists)
        action_out = torch.stack(action_list, dim=-1)
    else:
        # Export every branch, not just dists[0]; with a single branch this
        # equals dists[0].all_log_prob().
        action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
    return (
        action_out,
        self.version_number,
        torch.Tensor([self.network_body.memory_size]),
        self.is_continuous_int,
        self.act_size_vector,
    )
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Performs update on model.

    Runs one PPO optimization step on ``batch``:
    ``policy_loss + 0.5 * value_loss - beta * masked_mean(entropy)``, with
    learning rate, epsilon and beta taken from their decay schedules at the
    policy's current step. Each reward signal's own ``update`` is then run.

    :param batch: Batch of experiences.
    :param num_sequences: Number of sequences to process.
        NOTE(review): not read anywhere in this body.
    :return: Results of update (loss values and schedule values, plus
        whatever each reward provider's ``update`` returns).
    """
    # Get decayed parameters
    decay_lr = self.decay_learning_rate.get_value(
        self.policy.get_current_step())
    decay_eps = self.decay_epsilon.get_value(
        self.policy.get_current_step())
    decay_bet = self.decay_beta.get_value(self.policy.get_current_step())
    returns = {}
    old_values = {}
    for name in self.reward_signals:
        old_values[name] = ModelUtils.list_to_tensor(
            batch[RewardSignalUtil.value_estimates_key(name)])
        returns[name] = ModelUtils.list_to_tensor(
            batch[RewardSignalUtil.returns_key(name)])

    n_obs = len(self.policy.behavior_spec.observation_specs)
    current_obs = ObsUtil.from_buffer(batch, n_obs)
    # Convert to tensors
    current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

    act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
    actions = AgentAction.from_buffer(batch)

    # Keep only the stored memory at the start of each sequence.
    memories = [
        ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
        for i in range(0, len(batch[BufferKey.MEMORY]),
                       self.policy.sequence_length)
    ]
    if len(memories) > 0:
        memories = torch.stack(memories).unsqueeze(0)

    # Get value memories
    value_memories = [
        ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
        for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                       self.policy.sequence_length)
    ]
    if len(value_memories) > 0:
        value_memories = torch.stack(value_memories).unsqueeze(0)

    log_probs, entropy = self.policy.evaluate_actions(
        current_obs,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=self.policy.sequence_length,
    )
    values, _ = self.critic.critic_pass(
        current_obs,
        memories=value_memories,
        sequence_length=self.policy.sequence_length,
    )
    old_log_probs = ActionLogProbs.from_buffer(batch).flatten()
    log_probs = log_probs.flatten()
    loss_masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS],
                                           dtype=torch.bool)
    value_loss = self.ppo_value_loss(values, old_values, returns, decay_eps,
                                     loss_masks)
    policy_loss = self.ppo_policy_loss(
        ModelUtils.list_to_tensor(batch[BufferKey.ADVANTAGES]),
        log_probs,
        old_log_probs,
        loss_masks,
    )
    # Entropy is subtracted so that higher entropy lowers the loss.
    loss = (policy_loss + 0.5 * value_loss -
            decay_bet * ModelUtils.masked_mean(entropy, loss_masks))

    # Set optimizer learning rate
    ModelUtils.update_learning_rate(self.optimizer, decay_lr)
    self.optimizer.zero_grad()
    loss.backward()

    self.optimizer.step()
    update_stats = {
        # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
        # TODO: After PyTorch is default, change to something more correct.
        "Losses/Policy Loss": torch.abs(policy_loss).item(),
        "Losses/Value Loss": value_loss.item(),
        "Policy/Learning Rate": decay_lr,
        "Policy/Epsilon": decay_eps,
        "Policy/Beta": decay_bet,
    }

    for reward_provider in self.reward_signals.values():
        update_stats.update(reward_provider.update(batch))

    return update_stats
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Updates model using buffer.

    Computes the SAC Q, value, policy and entropy-coefficient losses and steps
    the policy, value and entropy optimizers in that order, then soft-updates
    the target network toward the critic using ``self.tau``.

    :param batch: Experience mini-batch.
    :param num_sequences: Number of trajectories in batch.
        NOTE(review): not read anywhere in this body.
    :return: Output from update process (losses, entropy coefficients and
        learning rate).
    """
    rewards = {}
    for name in self.reward_signals:
        rewards[name] = ModelUtils.list_to_tensor(batch[f"{name}_rewards"])

    n_obs = len(self.policy.behavior_spec.sensor_specs)
    current_obs = ObsUtil.from_buffer(batch, n_obs)
    # Convert to tensors
    current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

    next_obs = ObsUtil.from_buffer_next(batch, n_obs)
    # Convert to tensors
    next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

    act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
    actions = AgentAction.from_dict(batch)

    # Keep only the stored memory at the start of each sequence.
    memories_list = [
        ModelUtils.list_to_tensor(batch["memory"][i]) for i in range(
            0, len(batch["memory"]), self.policy.sequence_length)
    ]
    # LSTM shouldn't have sequence length <1, but stop it from going out of the index if true.
    offset = 1 if self.policy.sequence_length > 1 else 0
    next_memories_list = [
        ModelUtils.list_to_tensor(
            # only pass value part of memory to target network
            batch["memory"][i]
            [self.policy.m_size // 2:])
        for i in range(offset, len(batch["memory"]),
                       self.policy.sequence_length)
    ]

    if len(memories_list) > 0:
        memories = torch.stack(memories_list).unsqueeze(0)
        next_memories = torch.stack(next_memories_list).unsqueeze(0)
    else:
        memories = None
        next_memories = None
    # Q network memories are 0'ed out, since we don't have them during inference.
    q_memories = (torch.zeros_like(next_memories)
                  if next_memories is not None else None)

    # Copy normalizers from policy
    self.value_network.q1_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body)
    self.value_network.q2_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body)
    self.target_network.network_body.copy_normalization(
        self.policy.actor_critic.network_body)
    (
        sampled_actions,
        log_probs,
        _,
        value_estimates,
        _,
    ) = self.policy.actor_critic.get_action_stats_and_value(
        current_obs,
        masks=act_masks,
        memories=memories,
        sequence_length=self.policy.sequence_length,
    )
    cont_sampled_actions = sampled_actions.continuous_tensor
    cont_actions = actions.continuous_tensor
    # Q values of the actions just sampled by the policy (used by the
    # policy and value losses); Q2 is called with q2_grad=False here.
    q1p_out, q2p_out = self.value_network(
        current_obs,
        cont_sampled_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
        q2_grad=False,
    )
    # Q values of the actions stored in the buffer (used by the Q losses).
    q1_out, q2_out = self.value_network(
        current_obs,
        cont_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )

    if self._action_spec.discrete_size > 0:
        disc_actions = actions.discrete_tensor
        q1_stream = self._condense_q_streams(q1_out, disc_actions)
        q2_stream = self._condense_q_streams(q2_out, disc_actions)
    else:
        q1_stream, q2_stream = q1_out, q2_out

    with torch.no_grad():
        target_values, _ = self.target_network(
            next_obs,
            memories=next_memories,
            sequence_length=self.policy.sequence_length,
        )
    masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
    dones = ModelUtils.list_to_tensor(batch["done"])

    q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                       dones, rewards, masks)
    value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                     q2p_out, masks)
    policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
    entropy_loss = self.sac_entropy_loss(log_probs, masks)

    total_value_loss = q1_loss + q2_loss + value_loss

    decay_lr = self.decay_learning_rate.get_value(
        self.policy.get_current_step())
    ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
    self.value_optimizer.zero_grad()
    total_value_loss.backward()
    self.value_optimizer.step()

    ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
    self.entropy_optimizer.zero_grad()
    entropy_loss.backward()
    self.entropy_optimizer.step()

    # Update target network
    ModelUtils.soft_update(self.policy.actor_critic.critic,
                           self.target_network, self.tau)
    update_stats = {
        "Losses/Policy Loss": policy_loss.item(),
        "Losses/Value Loss": value_loss.item(),
        "Losses/Q1 Loss": q1_loss.item(),
        "Losses/Q2 Loss": q2_loss.item(),
        "Policy/Discrete Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
        "Policy/Continuous Entropy Coeff": torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
        "Policy/Learning Rate": decay_lr,
    }

    return update_stats
def sac_value_loss(
    self,
    log_probs: ActionLogProbs,
    values: Dict[str, torch.Tensor],
    q1p_out: Dict[str, torch.Tensor],
    q2p_out: Dict[str, torch.Tensor],
    loss_masks: torch.Tensor,
) -> torch.Tensor:
    """
    Compute the SAC state-value (V) loss, averaged over reward streams.

    For each reward stream the regression target is the soft value
    ``min(Q1, Q2) - alpha * log pi``: the smaller of the two Q estimates for
    the policy's own actions minus the entropy-coefficient-weighted
    log-probability of those actions. For discrete branches, both the Q
    values and the ``p * log p`` terms are taken in expectation under the
    action probabilities branch by branch, then averaged across branches.
    The target is built under ``torch.no_grad()``, so only ``values``
    receives gradient from this loss.

    :param log_probs: Log-probabilities of the actions sampled from the policy.
    :param values: V estimates, keyed by reward stream name.
    :param q1p_out: Q1 outputs for the policy's sampled actions, keyed by stream.
    :param q2p_out: Q2 outputs for the policy's sampled actions, keyed by stream.
    :param loss_masks: Boolean mask forwarded to ``ModelUtils.masked_mean`` so
        that masked-out (presumably padded) entries do not contribute.
    :return: Scalar loss: the mean over streams of ``0.5 * masked MSE``.
    :raises UnityTrainerException: If the loss is Inf or NaN.
    """
    has_discrete = self._action_spec.discrete_size > 0
    has_continuous = self._action_spec.continuous_size > 0

    v_backups: Dict[str, torch.Tensor] = {}
    with torch.no_grad():
        _cont_ent_coef = self._log_ent_coef.continuous.exp()
        _disc_ent_coef = self._log_ent_coef.discrete.exp()

        if not has_discrete:
            # Continuous-only: V = min(Q1, Q2) - sum(alpha * log pi).
            cont_ent_term = torch.sum(
                _cont_ent_coef * log_probs.continuous_tensor, dim=1
            )
            for name in values.keys():
                min_policy_q = torch.min(q1p_out[name], q2p_out[name])
                v_backups[name] = min_policy_q - cont_ent_term
        else:
            disc_log_probs = log_probs.all_discrete_tensor
            disc_action_probs = disc_log_probs.exp()
            branches = self._action_spec.discrete_branches

            def _expected_q(q_out: torch.Tensor) -> torch.Tensor:
                # Weight each action's Q by its probability, sum within each
                # branch, then average the per-branch expectations.
                branched_q = ModelUtils.break_into_branches(
                    q_out * disc_action_probs, branches
                )
                return torch.mean(
                    torch.stack(
                        [torch.sum(_br, dim=1, keepdim=True) for _br in branched_q]
                    ),
                    dim=0,
                )

            # alpha_b * sum(p * log p) for every branch b (each branch has its
            # own entropy coefficient), averaged over branches. This term is
            # the negative entropy, so subtracting it adds an entropy bonus.
            branched_per_action_ent = ModelUtils.break_into_branches(
                disc_log_probs * disc_action_probs, branches
            )
            disc_ent_term = torch.mean(
                torch.stack(
                    [
                        torch.sum(_disc_ent_coef[i] * _lp, dim=1, keepdim=True)
                        for i, _lp in enumerate(branched_per_action_ent)
                    ]
                ),
                dim=0,
            )
            if has_continuous:
                cont_ent_term = torch.sum(
                    _cont_ent_coef * log_probs.continuous_tensor,
                    dim=1,
                    keepdim=True,
                )

            for name in values.keys():
                min_policy_q = torch.min(
                    _expected_q(q1p_out[name]), _expected_q(q2p_out[name])
                )
                v_backup = min_policy_q - disc_ent_term
                if has_continuous:
                    # The continuous entropy term must be SUBTRACTED, exactly
                    # as in the continuous-only path: V = Q - alpha * log pi.
                    v_backup = v_backup - cont_ent_term
                v_backups[name] = v_backup.squeeze()

    value_losses = []
    for name in values.keys():
        # reduction="none" keeps one squared error per sample so loss_masks can
        # exclude entries; masking an already-averaged scalar excludes nothing.
        squared_error = torch.nn.functional.mse_loss(
            values[name], v_backups[name], reduction="none"
        )
        value_losses.append(0.5 * ModelUtils.masked_mean(squared_error, loss_masks))
    value_loss = torch.mean(torch.stack(value_losses))
    # isfinite is False for both +/-Inf and NaN.
    if not torch.isfinite(value_loss).all():
        raise UnityTrainerException("Inf found")
    return value_loss
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Run a single PPO optimization step on one mini-batch.

    The optimized objective is the policy loss plus half the value loss,
    minus the masked mean entropy scaled by the current (decayed) beta.
    After the gradient step, every reward signal is also updated with the
    same batch and its statistics are merged into the result.

    :param batch: Mini-batch of experiences.
    :param num_sequences: Number of sequences in the batch (not read here).
    :return: Training statistics keyed by name, merged with whatever each
        reward provider's ``update`` returns.
    """
    # Current values of the decaying hyperparameter schedules.
    lr = self.decay_learning_rate.get_value(self.policy.get_current_step())
    epsilon = self.decay_epsilon.get_value(self.policy.get_current_step())
    beta = self.decay_beta.get_value(self.policy.get_current_step())

    returns = {}
    old_values = {}
    for signal_name in self.reward_signals:
        old_values[signal_name] = ModelUtils.list_to_tensor(
            batch[f"{signal_name}_value_estimates"]
        )
        returns[signal_name] = ModelUtils.list_to_tensor(
            batch[f"{signal_name}_returns"]
        )

    vec_obs = [ModelUtils.list_to_tensor(batch["vector_obs"])]
    act_masks = ModelUtils.list_to_tensor(batch["action_mask"])
    if not self.policy.use_continuous_act:
        actions = ModelUtils.list_to_tensor(batch["actions"], dtype=torch.long)
    else:
        actions = ModelUtils.list_to_tensor(batch["actions_pre"]).unsqueeze(-1)

    # Keep one stored memory per sequence (every sequence_length-th entry).
    seq_len = self.policy.sequence_length
    stored_memories = batch["memory"]
    memories = [
        ModelUtils.list_to_tensor(stored_memories[i])
        for i in range(0, len(stored_memories), seq_len)
    ]
    if memories:
        memories = torch.stack(memories).unsqueeze(0)

    vis_obs = []
    if self.policy.use_vis_obs:
        visual_processors = self.policy.actor_critic.network_body.visual_processors
        vis_obs = [
            ModelUtils.list_to_tensor(batch["visual_obs%d" % idx])
            for idx, _ in enumerate(visual_processors)
        ]

    log_probs, entropy, values = self.policy.evaluate_actions(
        vec_obs,
        vis_obs,
        masks=act_masks,
        actions=actions,
        memories=memories,
        seq_len=seq_len,
    )
    loss_masks = ModelUtils.list_to_tensor(batch["masks"], dtype=torch.bool)
    value_loss = self.ppo_value_loss(
        values, old_values, returns, epsilon, loss_masks
    )
    advantages = ModelUtils.list_to_tensor(batch["advantages"])
    old_action_probs = ModelUtils.list_to_tensor(batch["action_probs"])
    policy_loss = self.ppo_policy_loss(
        advantages, log_probs, old_action_probs, loss_masks
    )
    masked_entropy = ModelUtils.masked_mean(entropy, loss_masks)
    loss = policy_loss + 0.5 * value_loss - beta * masked_entropy

    # Apply the scheduled learning rate, then take one optimizer step.
    ModelUtils.update_learning_rate(self.optimizer, lr)
    self.optimizer.zero_grad()
    loss.backward()
    self.optimizer.step()

    update_stats = {
        # NOTE: abs() is not technically correct, but matches the behavior in TensorFlow.
        # TODO: After PyTorch is default, change to something more correct.
        "Losses/Policy Loss": torch.abs(policy_loss).item(),
        "Losses/Value Loss": value_loss.item(),
        "Policy/Learning Rate": lr,
        "Policy/Epsilon": epsilon,
        "Policy/Beta": beta,
    }
    for provider in self.reward_signals.values():
        update_stats.update(provider.update(batch))
    return update_stats
def update(self, batch: AgentBuffer, num_sequences: int) -> Dict[str, float]:
    """
    Perform one Soft Actor-Critic update on a mini-batch.

    Converts ``batch`` into tensors and runs the following evaluations:
    the policy (sampled actions and log-probs), the critic (V estimates),
    the twin Q networks (on both the policy's sampled actions and the
    buffered actions) and the target network on the next observations.
    It then computes the Q1/Q2, value, policy and entropy-coefficient
    losses. The policy, value and entropy optimizers are stepped in that
    order. Finally the target network is soft-updated from the critic with
    rate ``self.tau``.

    :param batch: Experience mini-batch.
    :param num_sequences: Number of sequences in the batch; not read by this
        method.
    :return: Dictionary of training statistics: losses, mean entropy
        coefficients and the current learning rate.
    """
    rewards = {}
    for name in self.reward_signals:
        rewards[name] = ModelUtils.list_to_tensor(
            batch[RewardSignalUtil.rewards_key(name)])

    n_obs = len(self.policy.behavior_spec.observation_specs)
    current_obs = ObsUtil.from_buffer(batch, n_obs)
    # Convert to tensors
    current_obs = [ModelUtils.list_to_tensor(obs) for obs in current_obs]

    next_obs = ObsUtil.from_buffer_next(batch, n_obs)
    # Convert to tensors
    next_obs = [ModelUtils.list_to_tensor(obs) for obs in next_obs]

    act_masks = ModelUtils.list_to_tensor(batch[BufferKey.ACTION_MASK])
    actions = AgentAction.from_buffer(batch)

    # Keep one stored actor memory per sequence (every sequence_length-th entry).
    memories_list = [
        ModelUtils.list_to_tensor(batch[BufferKey.MEMORY][i])
        for i in range(0, len(batch[BufferKey.MEMORY]),
                       self.policy.sequence_length)
    ]
    # Critic memories are sampled the same way, one per sequence.
    value_memories_list = [
        ModelUtils.list_to_tensor(batch[BufferKey.CRITIC_MEMORY][i])
        for i in range(0, len(batch[BufferKey.CRITIC_MEMORY]),
                       self.policy.sequence_length)
    ]

    # NOTE(review): only memories_list is checked; this assumes actor and
    # critic memories are either both present or both absent — confirm.
    if len(memories_list) > 0:
        memories = torch.stack(memories_list).unsqueeze(0)
        value_memories = torch.stack(value_memories_list).unsqueeze(0)
    else:
        memories = None
        value_memories = None

    # Q and V network memories are 0'ed out, since we don't have them during inference.
    q_memories = (torch.zeros_like(value_memories)
                  if value_memories is not None else None)

    # Copy normalizers from policy
    self.q_network.q1_network.network_body.copy_normalization(
        self.policy.actor.network_body)
    self.q_network.q2_network.network_body.copy_normalization(
        self.policy.actor.network_body)
    self.target_network.network_body.copy_normalization(
        self.policy.actor.network_body)
    self._critic.network_body.copy_normalization(
        self.policy.actor.network_body)
    # Sample fresh actions from the current policy; their log-probs feed the
    # value, policy and entropy losses below.
    sampled_actions, log_probs, _, _, = self.policy.actor.get_action_and_stats(
        current_obs,
        masks=act_masks,
        memories=memories,
        sequence_length=self.policy.sequence_length,
    )
    value_estimates, _ = self._critic.critic_pass(
        current_obs, value_memories,
        sequence_length=self.policy.sequence_length)

    cont_sampled_actions = sampled_actions.continuous_tensor
    cont_actions = actions.continuous_tensor
    # Q values for the policy's sampled actions (used by the value and policy
    # losses). q2_grad=False presumably blocks gradients through Q2 — confirm
    # in the Q network implementation.
    q1p_out, q2p_out = self.q_network(
        current_obs,
        cont_sampled_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
        q2_grad=False,
    )
    # Q values for the actions actually stored in the buffer (used by the Q losses).
    q1_out, q2_out = self.q_network(
        current_obs,
        cont_actions,
        memories=q_memories,
        sequence_length=self.policy.sequence_length,
    )

    if self._action_spec.discrete_size > 0:
        disc_actions = actions.discrete_tensor
        # Presumably selects the Q value of each taken discrete action — see
        # _condense_q_streams.
        q1_stream = self._condense_q_streams(q1_out, disc_actions)
        q2_stream = self._condense_q_streams(q2_out, disc_actions)
    else:
        q1_stream, q2_stream = q1_out, q2_out

    with torch.no_grad():
        # Since we didn't record the next value memories, evaluate one step in the critic to
        # get them.
        if value_memories is not None:
            # Get the first observation in each sequence
            just_first_obs = [
                _obs[::self.policy.sequence_length] for _obs in current_obs
            ]
            # One critic step from the stored initial memory yields the memory
            # used as the starting state for the next_obs sequences.
            _, next_value_memories = self._critic.critic_pass(
                just_first_obs, value_memories, sequence_length=1)
        else:
            next_value_memories = None
        target_values, _ = self.target_network(
            next_obs,
            memories=next_value_memories,
            sequence_length=self.policy.sequence_length,
        )
    masks = ModelUtils.list_to_tensor(batch[BufferKey.MASKS], dtype=torch.bool)
    dones = ModelUtils.list_to_tensor(batch[BufferKey.DONE])

    q1_loss, q2_loss = self.sac_q_loss(q1_stream, q2_stream, target_values,
                                       dones, rewards, masks)
    value_loss = self.sac_value_loss(log_probs, value_estimates, q1p_out,
                                     q2p_out, masks)
    # NOTE(review): an earlier sac_policy_loss definition in this file takes a
    # fourth `discrete` argument; confirm this 3-argument call matches the
    # definition actually used by this class.
    policy_loss = self.sac_policy_loss(log_probs, q1p_out, masks)
    entropy_loss = self.sac_entropy_loss(log_probs, masks)

    total_value_loss = q1_loss + q2_loss
    # With a shared critic the V loss is optimized by the policy optimizer
    # (and is therefore included in "Losses/Policy Loss" below); otherwise
    # it is optimized together with the Q losses.
    if self.policy.shared_critic:
        policy_loss += value_loss
    else:
        total_value_loss += value_loss

    decay_lr = self.decay_learning_rate.get_value(
        self.policy.get_current_step())
    ModelUtils.update_learning_rate(self.policy_optimizer, decay_lr)
    self.policy_optimizer.zero_grad()
    policy_loss.backward()
    self.policy_optimizer.step()

    ModelUtils.update_learning_rate(self.value_optimizer, decay_lr)
    self.value_optimizer.zero_grad()
    total_value_loss.backward()
    self.value_optimizer.step()

    ModelUtils.update_learning_rate(self.entropy_optimizer, decay_lr)
    self.entropy_optimizer.zero_grad()
    entropy_loss.backward()
    self.entropy_optimizer.step()

    # Update target network
    ModelUtils.soft_update(self._critic, self.target_network, self.tau)
    update_stats = {
        "Losses/Policy Loss": policy_loss.item(),
        "Losses/Value Loss": value_loss.item(),
        "Losses/Q1 Loss": q1_loss.item(),
        "Losses/Q2 Loss": q2_loss.item(),
        "Policy/Discrete Entropy Coeff":
        torch.mean(torch.exp(self._log_ent_coef.discrete)).item(),
        "Policy/Continuous Entropy Coeff":
        torch.mean(torch.exp(self._log_ent_coef.continuous)).item(),
        "Policy/Learning Rate": decay_lr,
    }

    return update_stats