def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
    num_entities = self.entity_num_max_elements
    if num_entities < 0:
        if exporting_to_onnx.is_exporting():
            raise UnityTrainerException(
                "Trying to export an attention mechanism that doesn't have a set max "
                "number of elements."
            )
        num_entities = entities.shape[1]

    if exporting_to_onnx.is_exporting():
        # When exporting to ONNX, we want to transpose the entities. This is
        # because ONNX only supports input in NCHW (channel first) format.
        # Barracuda also expects to get data in NCHW.
        entities = torch.transpose(entities, 2, 1).reshape(
            -1, num_entities, self.entity_size
        )

    if self.self_size > 0:
        expanded_self = x_self.reshape(-1, 1, self.self_size)
        expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
        # Concatenate all observations with self
        entities = torch.cat([expanded_self, entities], dim=2)
    # Encode entities
    encoded_entities = self.self_ent_encoder(entities)
    return encoded_entities

def get_action_out(
    self, inputs: torch.Tensor, masks: torch.Tensor
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Gets the tensors corresponding to the output of the policy network to be used for
    inference. Called by the Actor's forward call.
    :param inputs: The encoding from the network body
    :param masks: Action masks for discrete actions
    :return: A tuple of torch tensors corresponding to the inference output
    """
    dists = self._get_dists(inputs, masks)
    continuous_out, discrete_out, action_out_deprecated = None, None, None
    if self.action_spec.continuous_size > 0 and dists.continuous is not None:
        continuous_out = dists.continuous.exported_model_output()
        action_out_deprecated = dists.continuous.exported_model_output()
        if self._clip_action_on_export:
            continuous_out = torch.clamp(continuous_out, -3, 3) / 3
            action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3
    if self.action_spec.discrete_size > 0 and dists.discrete is not None:
        discrete_out_list = [
            discrete_dist.exported_model_output()
            for discrete_dist in dists.discrete
        ]
        discrete_out = torch.cat(discrete_out_list, dim=1)
        action_out_deprecated = torch.cat(discrete_out_list, dim=1)
    # deprecated action field does not support hybrid action
    if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
        action_out_deprecated = None
    return continuous_out, discrete_out, action_out_deprecated

def forward(
    self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> torch.Tensor:
    if self.concat_self:
        # Concatenate all observations with self
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entity_num_max_elements, entities):
            if num_entities < 0:
                if exporting_to_onnx.is_exporting():
                    raise UnityTrainerException(
                        "Trying to export an attention mechanism that doesn't have a set max "
                        "number of elements."
                    )
                num_entities = ent.shape[1]
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
    else:
        self_and_ent = entities
    # Encode and concatenate entities
    encoded_entities = torch.cat(
        [
            ent_encoder(x)
            for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
        ],
        dim=1,
    )
    encoded_entities = self.embedding_norm(encoded_entities)
    return encoded_entities

def forward(
    self,
    inputs: List[torch.Tensor],
    actions: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    encodes = []
    for idx, processor in enumerate(self.processors):
        obs_input = inputs[idx]
        processed_obs = processor(obs_input)
        encodes.append(processed_obs)

    if len(encodes) == 0:
        raise Exception("No valid inputs to network.")

    # Constants don't work in Barracuda
    if actions is not None:
        inputs = torch.cat(encodes + [actions], dim=-1)
    else:
        inputs = torch.cat(encodes, dim=-1)
    encoding = self.linear_encoder(inputs)

    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        encoding = encoding.reshape([-1, sequence_length, self.h_size])
        encoding, memories = self.lstm(encoding, memories)
        encoding = encoding.reshape([-1, self.m_size // 2])
    return encoding, memories

def forward(
    self,
    obs_only: List[List[torch.Tensor]],
    obs: List[List[torch.Tensor]],
    actions: List[AgentAction],
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Returns the encoding of all the processed observations. If memory is enabled,
    returns the memories as well.
    :param obs_only: Observations to be processed that do not have corresponding actions.
        These are encoded with the obs_encoder.
    :param obs: Observations to be processed that do have corresponding actions.
        After concatenation with actions, these are processed with obs_action_encoder.
    :param actions: After concatenation with obs, these are processed with obs_action_encoder.
    :param memories: If using memory, a Tensor of initial memories.
    :param sequence_length: If using memory, the sequence length.
    """
    self_attn_masks = []
    self_attn_inputs = []
    concat_f_inp = []
    if obs:
        obs_attn_mask = self._get_masks_from_nans(obs)
        obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
        for inputs, action in zip(obs, actions):
            encoded = self.observation_encoder(inputs)
            cat_encodes = [
                encoded,
                action.to_flat(self.action_spec.discrete_branches),
            ]
            concat_f_inp.append(torch.cat(cat_encodes, dim=1))
        f_inp = torch.stack(concat_f_inp, dim=1)
        self_attn_masks.append(obs_attn_mask)
        self_attn_inputs.append(self.obs_action_encoder(None, f_inp))

    concat_encoded_obs = []
    if obs_only:
        obs_only_attn_mask = self._get_masks_from_nans(obs_only)
        obs_only = self._copy_and_remove_nans_from_obs(obs_only, obs_only_attn_mask)
        for inputs in obs_only:
            encoded = self.observation_encoder(inputs)
            concat_encoded_obs.append(encoded)
        g_inp = torch.stack(concat_encoded_obs, dim=1)
        self_attn_masks.append(obs_only_attn_mask)
        self_attn_inputs.append(self.obs_encoder(None, g_inp))

    encoded_entity = torch.cat(self_attn_inputs, dim=1)
    encoded_state = self.self_attn(encoded_entity, self_attn_masks)

    encoding = self.linear_encoder(encoded_state)
    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        encoding = encoding.reshape([-1, sequence_length, self.h_size])
        encoding, memories = self.lstm(encoding, memories)
        encoding = encoding.reshape([-1, self.m_size // 2])
    return encoding, memories

def forward(
    self,
    inputs: List[torch.Tensor],
    actions: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    encodes = []
    var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

    for idx, processor in enumerate(self.processors):
        if not isinstance(processor, EntityEmbedding):
            # The input can be encoded without having to process other inputs
            obs_input = inputs[idx]
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        else:
            var_len_processor_inputs.append((processor, inputs[idx]))
    if len(encodes) != 0:
        encoded_self = torch.cat(encodes, dim=1)
        input_exist = True
    else:
        input_exist = False
    if len(var_len_processor_inputs) > 0:
        # Some inputs need to be processed with a variable length encoder
        masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
        embeddings: List[torch.Tensor] = []
        processed_self = self.x_self_encoder(encoded_self) if input_exist else None
        for processor, var_len_input in var_len_processor_inputs:
            embeddings.append(processor(processed_self, var_len_input))
        qkv = torch.cat(embeddings, dim=1)
        attention_embedding = self.rsa(qkv, masks)
        if not input_exist:
            encoded_self = torch.cat([attention_embedding], dim=1)
            input_exist = True
        else:
            encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)

    if not input_exist:
        raise Exception(
            "The trainer was unable to process any of the provided inputs. "
            "Make sure the trained agents have at least one sensor attached to them."
        )

    if actions is not None:
        encoded_self = torch.cat([encoded_self, actions], dim=1)
    encoding = self.linear_encoder(encoded_self)

    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        encoding = encoding.reshape([-1, sequence_length, self.h_size])
        encoding, memories = self.lstm(encoding, memories)
        encoding = encoding.reshape([-1, self.m_size // 2])
    return encoding, memories

def generate_input_helper(pattern):
    # Build a (batch_size, sum(pattern), size) tensor by alternating blocks of
    # random entities (even indices in the pattern) and all-zero padding entities
    # (odd indices in the pattern).
    _input = torch.zeros((batch_size, 0, size))
    for i in range(len(pattern)):
        if i % 2 == 0:
            _input = torch.cat(
                [_input, torch.rand((batch_size, pattern[i], size))], dim=1
            )
        else:
            _input = torch.cat(
                [_input, torch.zeros((batch_size, pattern[i], size))], dim=1
            )
    return _input

def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
    """
    Encode observations using a list of processors and an RSA.
    :param inputs: List of Tensors corresponding to a set of obs.
    :param processors: a ModuleList of the input processors to be applied to these obs.
    :param rsa: Optionally, an RSA to use for variable length obs.
    :param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
    """
    encodes = []
    var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

    for idx, processor in enumerate(self.processors):
        if not isinstance(processor, EntityEmbedding):
            # The input can be encoded without having to process other inputs
            obs_input = inputs[idx]
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        else:
            var_len_processor_inputs.append((processor, inputs[idx]))
    if len(encodes) != 0:
        encoded_self = torch.cat(encodes, dim=1)
        input_exist = True
    else:
        input_exist = False
    if len(var_len_processor_inputs) > 0 and self.rsa is not None:
        # Some inputs need to be processed with a variable length encoder
        masks = get_zero_entities_mask([p_i[1] for p_i in var_len_processor_inputs])
        embeddings: List[torch.Tensor] = []
        processed_self = (
            self.x_self_encoder(encoded_self)
            if input_exist and self.x_self_encoder is not None
            else None
        )
        for processor, var_len_input in var_len_processor_inputs:
            embeddings.append(processor(processed_self, var_len_input))
        qkv = torch.cat(embeddings, dim=1)
        attention_embedding = self.rsa(qkv, masks)
        if not input_exist:
            encoded_self = torch.cat([attention_embedding], dim=1)
            input_exist = True
        else:
            encoded_self = torch.cat([encoded_self, attention_embedding], dim=1)
    if not input_exist:
        raise UnityTrainerException(
            "The trainer was unable to process any of the provided inputs. "
            "Make sure the trained agents have at least one sensor attached to them."
        )
    return encoded_self

def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
    """
    Flatten this AgentAction into a single torch Tensor of dimension
    (batch, num_continuous + num_one_hot_discrete). Discrete actions are converted
    into one-hot and concatenated with continuous actions.
    :param discrete_branches: List of sizes for discrete actions.
    :return: Tensor of flattened actions.
    """
    # if there are any discrete actions, create one-hot
    if self.discrete_list is not None and self.discrete_list:
        discrete_oh = ModelUtils.actions_to_onehot(
            self.discrete_tensor, discrete_branches
        )
        discrete_oh = torch.cat(discrete_oh, dim=1)
    else:
        discrete_oh = torch.empty(0)
    return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)

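# The block below is a minimal, self-contained sketch (not part of the trainer code)
# of what to_flat produces: each discrete branch index becomes a one-hot vector, and
# the one-hots are concatenated onto the continuous actions. The shapes and branch
# sizes here are made up purely for illustration.
import torch
import torch.nn.functional as F

continuous = torch.randn(4, 2)                              # (batch=4, 2 continuous actions)
discrete = torch.tensor([[0, 2], [1, 0], [2, 1], [0, 0]])   # (batch=4, 2 discrete branches)
branch_sizes = [3, 3]
one_hots = [
    F.one_hot(discrete[:, i], num_classes=n).float()
    for i, n in enumerate(branch_sizes)
]
flat = torch.cat([continuous] + one_hots, dim=-1)           # (4, 2 + 3 + 3)
print(flat.shape)  # torch.Size([4, 8])
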
def get_action_stats_and_value(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[
    AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
]:
    if self.use_lstm:
        # Use only the back half of memories for critic and actor
        actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
    else:
        critic_mem = None
        actor_mem = None
    encoding, actor_mem_outs = self.network_body(
        vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
    )
    action, log_probs, entropies = self.action_model(encoding, masks)
    value_outputs, critic_mem_outs = self.critic(
        vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
    )
    if self.use_lstm:
        mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
    else:
        mem_out = None
    return action, log_probs, entropies, value_outputs, mem_out

def get_goal_encoding(self, inputs: List[torch.Tensor]) -> torch.Tensor:
    """
    Encode observations corresponding to goals using a list of processors.
    :param inputs: List of Tensors corresponding to a set of obs.
    """
    encodes = []
    for idx in self._goal_processor_indices:
        processor = self.processors[idx]
        if not isinstance(processor, EntityEmbedding):
            # The input can be encoded without having to process other inputs
            obs_input = inputs[idx]
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)
        else:
            raise UnityTrainerException(
                "One of the goals uses variable length observations. This use "
                "case is not supported."
            )
    if len(encodes) != 0:
        encoded = torch.cat(encodes, dim=1)
    else:
        raise UnityTrainerException(
            "The trainer was unable to process any of the goals provided as input."
        )
    return encoded

def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
    # Gather the maximum number of entities information
    mask = torch.cat(key_masks, dim=1)

    inp = self.embedding_norm(inp)
    # Feed to self attention
    query = self.fc_q(inp)  # (b, n_q, emb)
    key = self.fc_k(inp)  # (b, n_k, emb)
    value = self.fc_v(inp)  # (b, n_k, emb)

    # Only use max num if provided
    if self.max_num_ent is not None:
        num_ent = self.max_num_ent
    else:
        num_ent = inp.shape[1]
        if exporting_to_onnx.is_exporting():
            raise UnityTrainerException(
                "Trying to export an attention mechanism that doesn't have a set max "
                "number of elements."
            )

    output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
    # Residual
    output = self.fc_out(output) + inp
    output = self.residual_norm(output)
    # Average Pooling
    numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
    output = numerator / denominator
    return output

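# A standalone sketch (illustration only, not the trainer's code) of the masked
# average pooling used at the end of the residual self-attention block above:
# entities whose mask value is 1 (padding) are excluded from the mean. The batch,
# entity count, and embedding size below are arbitrary example values.
import torch

EPSILON = 1e-7
batch, num_ent, emb = 2, 4, 8
output = torch.randn(batch, num_ent, emb)
mask = torch.tensor([[0.0, 0.0, 1.0, 1.0], [0.0, 1.0, 1.0, 1.0]])  # 1 = padded entity

numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
pooled = numerator / denominator  # (batch, emb): mean over real entities only
print(pooled.shape)  # torch.Size([2, 8])
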
def _get_probs_and_entropy(
    self, actions: AgentAction, dists: DistInstances
) -> Tuple[ActionLogProbs, torch.Tensor]:
    """
    Computes the log probabilities of the actions given distributions and entropies of
    the given distributions.
    :param actions: The AgentAction
    :param dists: The DistInstances tuple
    :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
    """
    entropies_list: List[torch.Tensor] = []
    continuous_log_prob: Optional[torch.Tensor] = None
    discrete_log_probs: Optional[List[torch.Tensor]] = None
    all_discrete_log_probs: Optional[List[torch.Tensor]] = None
    # This checks None because mypy complains otherwise
    if dists.continuous is not None:
        continuous_log_prob = dists.continuous.log_prob(actions.continuous_tensor)
        entropies_list.append(dists.continuous.entropy())
    if dists.discrete is not None:
        discrete_log_probs = []
        all_discrete_log_probs = []
        for discrete_action, discrete_dist in zip(
            actions.discrete_list, dists.discrete  # type: ignore
        ):
            discrete_log_prob = discrete_dist.log_prob(discrete_action)
            entropies_list.append(discrete_dist.entropy())
            discrete_log_probs.append(discrete_log_prob)
            all_discrete_log_probs.append(discrete_dist.all_log_prob())
    action_log_probs = ActionLogProbs(
        continuous_log_prob, discrete_log_probs, all_discrete_log_probs
    )
    entropies = torch.cat(entropies_list, dim=1)
    return action_log_probs, entropies

def get_dist_and_value(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
    if self.use_lstm:
        # Use only the back half of memories for critic and actor
        actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
    else:
        critic_mem = None
        actor_mem = None
    dists, actor_mem_outs = self.get_dists(
        vec_inputs,
        vis_inputs,
        memories=actor_mem,
        sequence_length=sequence_length,
        masks=masks,
    )
    value_outputs, critic_mem_outs = self.critic(
        vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
    )
    if self.use_lstm:
        mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
    else:
        mem_out = None
    return dists, value_outputs, mem_out

def compute_estimate(
    self, mini_batch: AgentBuffer, use_vail_noise: bool = False
) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
    """
    Given a mini_batch, computes the estimate (How much the discriminator believes
    the data was sampled from the demonstration data).
    :param mini_batch: The AgentBuffer of data
    :param use_vail_noise: Only when using VAIL : If true, will sample the code, if
    false, will return the mean of the code.
    """
    vec_inputs, vis_inputs = self.get_state_inputs(mini_batch)
    if self._settings.use_actions:
        actions = self.get_action_input(mini_batch)
        dones = torch.as_tensor(mini_batch["done"], dtype=torch.float).unsqueeze(1)
        action_inputs = torch.cat([actions, dones], dim=1)
        hidden, _ = self.encoder(vec_inputs, vis_inputs, action_inputs)
    else:
        hidden, _ = self.encoder(vec_inputs, vis_inputs)
    z_mu: Optional[torch.Tensor] = None
    if self._settings.use_vail:
        z_mu = self._z_mu_layer(hidden)
        hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
    estimate = self._estimator(hidden)
    return estimate, z_mu

def forward(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    masks: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
) -> Tuple[torch.Tensor, int, int, int, int]:
    """
    Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
    """
    dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
    if self.action_spec.is_continuous():
        action_list = self.sample_action(dists)
        action_out = torch.stack(action_list, dim=-1)
        if self._clip_action_on_export:
            action_out = torch.clamp(action_out, -3, 3) / 3
    else:
        action_out = torch.cat([dist.all_log_prob() for dist in dists], dim=1)
    return (
        action_out,
        self.version_number,
        torch.Tensor([self.network_body.memory_size]),
        self.is_continuous_int,
        self.act_size_vector,
    )

def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    In the continuous case, returns the predicted action.
    In the discrete case, returns the predicted action probabilities for each branch.
    """
    inverse_model_input = torch.cat(
        (self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
    )
    hidden = self.inverse_model_action_prediction(inverse_model_input)
    if self._policy_specs.is_action_continuous():
        return hidden
    else:
        branches = ModelUtils.break_into_branches(
            hidden, self._policy_specs.discrete_action_branches
        )
        branches = [torch.softmax(b, dim=1) for b in branches]
        return torch.cat(branches, dim=1)

def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
    if self.self_size > 0:
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max "
                    "number of elements."
                )
            num_entities = entities.shape[1]
        expanded_self = x_self.reshape(-1, 1, self.self_size)
        expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
        # Concatenate all observations with self
        entities = torch.cat([expanded_self, entities], dim=2)
    # Encode entities
    encoded_entities = self.self_ent_encoder(entities)
    return encoded_entities

def forward(self, input_tensor: torch.Tensor, goal_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
    activation = torch.cat([input_tensor, goal_tensor], dim=-1)
    for layer in self.layers:
        if isinstance(layer, HyperNetwork):
            activation = layer(activation, goal_tensor)
        else:
            activation = layer(activation)
    return activation

def forward(self, action: AgentAction) -> torch.Tensor:
    """
    Returns a tensor corresponding to the flattened action
    :param action: An AgentAction object
    """
    action_list: List[torch.Tensor] = []
    if self._specs.continuous_size > 0:
        action_list.append(action.continuous_tensor)
    if self._specs.discrete_size > 0:
        flat_discrete = torch.cat(
            ModelUtils.actions_to_onehot(
                torch.as_tensor(action.discrete_tensor, dtype=torch.long),
                self._specs.discrete_branches,
            ),
            dim=1,
        )
        action_list.append(flat_discrete)
    return torch.cat(action_list, dim=1)

def forward(
    self, input_tensor: torch.Tensor, memories: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    # We don't use torch.split here since it is not supported by Barracuda
    h0 = memories[:, :, : self.hidden_size]
    c0 = memories[:, :, self.hidden_size :]
    hidden = (h0, c0)
    lstm_out, hidden_out = self.lstm(input_tensor, hidden)
    output_mem = torch.cat(hidden_out, dim=-1)
    return lstm_out, output_mem

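# Standalone illustration of the memory layout used above: the LSTM hidden and cell
# states are packed side by side in a single tensor and recovered by slicing rather
# than torch.split (which the comment notes Barracuda does not support). The sizes
# below are arbitrary for the example.
import torch

hidden_size = 4
memories = torch.randn(1, 2, 2 * hidden_size)  # (num_layers, batch, 2 * hidden_size)
h0 = memories[:, :, :hidden_size]
c0 = memories[:, :, hidden_size:]
repacked = torch.cat((h0, c0), dim=-1)
assert torch.equal(repacked, memories)  # slicing then concatenating round-trips exactly
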
def forward(
    self,
    x_self: torch.Tensor,
    entities: List[torch.Tensor],
    key_masks: List[torch.Tensor],
) -> torch.Tensor:
    # Gather the maximum number of entities information
    if self.entities_num_max_elements is None:
        self.entities_num_max_elements = []
        for ent in entities:
            self.entities_num_max_elements.append(ent.shape[1])
    # Concatenate all observations with self
    self_and_ent: List[torch.Tensor] = []
    for num_entities, ent in zip(self.entities_num_max_elements, entities):
        expanded_self = x_self.reshape(-1, 1, self.self_size)
        # .repeat(
        #     1, num_entities, 1
        # )
        expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
        self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
    # Generate the tensor that will serve as query, key and value to self attention
    qkv = torch.cat(
        [
            ent_encoder(x)
            for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
        ],
        dim=1,
    )
    mask = torch.cat(key_masks, dim=1)
    # Feed to self attention
    max_num_ent = sum(self.entities_num_max_elements)
    output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
    # Residual
    output = self.residual_layer(output) + qkv
    # Average Pooling
    numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
    denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
    output = numerator / denominator
    # Residual between x_self and the output of the module
    output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
    return output

def forward(self, action: torch.Tensor) -> torch.Tensor:
    if self._specs.is_action_continuous():
        return action
    else:
        return torch.cat(
            ModelUtils.actions_to_onehot(
                torch.as_tensor(action, dtype=torch.long),
                self._specs.discrete_action_branches,
            ),
            dim=1,
        )

def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Uses the current state embedding and the action of the mini_batch to predict
    the next state embedding.
    """
    actions = AgentAction.from_buffer(mini_batch)
    flattened_action = self._action_flattener.forward(actions)
    forward_model_input = torch.cat(
        (self.get_current_state(mini_batch), flattened_action), dim=1
    )
    return self.forward_model_next_state_prediction(forward_model_input)

def predict_action(self, mini_batch: AgentBuffer) -> ActionPredictionTuple:
    """
    In the continuous case, returns the predicted action.
    In the discrete case, returns the predicted action probabilities for each branch.
    """
    inverse_model_input = torch.cat(
        (self.get_current_state(mini_batch), self.get_next_state(mini_batch)), dim=1
    )

    continuous_pred = None
    discrete_pred = None
    hidden = self.inverse_model_action_encoding(inverse_model_input)
    if self._action_spec.continuous_size > 0:
        continuous_pred = self.continuous_action_prediction(hidden)
    if self._action_spec.discrete_size > 0:
        raw_discrete_pred = self.discrete_action_prediction(hidden)
        branches = ModelUtils.break_into_branches(
            raw_discrete_pred, self._action_spec.discrete_branches
        )
        branches = [torch.softmax(b, dim=1) for b in branches]
        discrete_pred = torch.cat(branches, dim=1)
    return ActionPredictionTuple(continuous_pred, discrete_pred)

def test_visual_encoder_trains(vis_class, size):
    torch.manual_seed(0)
    image_size = (size, size, 1)
    batch = 100

    inputs = torch.cat(
        [torch.zeros((batch,) + image_size), torch.ones((batch,) + image_size)], dim=0
    )
    target = torch.cat([torch.zeros((batch,)), torch.ones((batch,))], dim=0)
    enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
    optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)

    for _ in range(15):
        prediction = enc(inputs)[:, 0]
        loss = torch.mean((target - prediction) ** 2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    assert loss.item() < 0.05

def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
    """
    Uses the current state embedding and the action of the mini_batch to predict
    the next state embedding.
    """
    if self._policy_specs.is_action_continuous():
        action = ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.float)
    else:
        action = torch.cat(
            ModelUtils.actions_to_onehot(
                ModelUtils.list_to_tensor(mini_batch["actions"], dtype=torch.long),
                self._policy_specs.discrete_action_branches,
            ),
            dim=1,
        )
    forward_model_input = torch.cat(
        (self.get_current_state(mini_batch), action), dim=1
    )
    return self.forward_model_next_state_prediction(forward_model_input)

def compute_gradient_magnitude(
    self, policy_batch: AgentBuffer, expert_batch: AgentBuffer
) -> torch.Tensor:
    """
    Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability esp. for off-policy.
    Compute gradients w.r.t randomly interpolated input.
    """
    policy_inputs = self.get_state_inputs(policy_batch)
    expert_inputs = self.get_state_inputs(expert_batch)
    interp_inputs = []
    for policy_input, expert_input in zip(policy_inputs, expert_inputs):
        obs_epsilon = torch.rand(policy_input.shape)
        interp_input = obs_epsilon * policy_input + (1 - obs_epsilon) * expert_input
        interp_input.requires_grad = True  # For gradient calculation
        interp_inputs.append(interp_input)
    if self._settings.use_actions:
        policy_action = self.get_action_input(policy_batch)
        expert_action = self.get_action_input(expert_batch)
        action_epsilon = torch.rand(policy_action.shape)
        policy_dones = torch.as_tensor(
            policy_batch[BufferKey.DONE], dtype=torch.float
        ).unsqueeze(1)
        expert_dones = torch.as_tensor(
            expert_batch[BufferKey.DONE], dtype=torch.float
        ).unsqueeze(1)
        dones_epsilon = torch.rand(policy_dones.shape)
        action_inputs = torch.cat(
            [
                action_epsilon * policy_action + (1 - action_epsilon) * expert_action,
                dones_epsilon * policy_dones + (1 - dones_epsilon) * expert_dones,
            ],
            dim=1,
        )
        action_inputs.requires_grad = True
        hidden, _ = self.encoder(interp_inputs, action_inputs)
        encoder_input = tuple(interp_inputs + [action_inputs])
    else:
        hidden, _ = self.encoder(interp_inputs)
        encoder_input = tuple(interp_inputs)
    if self._settings.use_vail:
        use_vail_noise = True
        z_mu = self._z_mu_layer(hidden)
        hidden = z_mu + torch.randn_like(z_mu) * self._z_sigma * use_vail_noise
    estimate = self._estimator(hidden).squeeze(1).sum()
    gradient = torch.autograd.grad(estimate, encoder_input, create_graph=True)[0]
    # Norm's gradient could be NaN at 0. Use our own safe_norm
    safe_norm = (torch.sum(gradient ** 2, dim=1) + self.EPSILON).sqrt()
    gradient_mag = torch.mean((safe_norm - 1) ** 2)
    return gradient_mag

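# Illustrative, self-contained sketch of the WGAN-GP style gradient penalty used
# above (https://arxiv.org/pdf/1704.00028), computed on a toy discriminator. The
# `disc` network, shapes, and batch size are made up for this example; the trainer's
# real encoder and estimator are more involved.
import torch
import torch.nn as nn

torch.manual_seed(0)
disc = nn.Sequential(nn.Linear(6, 16), nn.ReLU(), nn.Linear(16, 1))
policy_obs = torch.randn(8, 6)
expert_obs = torch.randn(8, 6)

# Interpolate between policy and expert samples, then penalize gradients whose
# norm deviates from 1 at the interpolated points.
eps = torch.rand(policy_obs.shape[0], 1)
interp = eps * policy_obs + (1 - eps) * expert_obs
interp.requires_grad = True
estimate = disc(interp).sum()
gradient = torch.autograd.grad(estimate, interp, create_graph=True)[0]
safe_norm = (torch.sum(gradient ** 2, dim=1) + 1e-7).sqrt()
gradient_penalty = torch.mean((safe_norm - 1) ** 2)
print(gradient_penalty.item())
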
def forward(
    self, x_self: torch.Tensor, entities: List[torch.Tensor]
) -> torch.Tensor:
    if self.concat_self:
        # Concatenate all observations with self
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entity_num_max_elements, entities):
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
    else:
        self_and_ent = entities
    # Encode and concatenate entities
    encoded_entities = torch.cat(
        [
            ent_encoder(x)
            for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
        ],
        dim=1,
    )
    return encoded_entities

def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
    mu = self.mu(inputs)
    if self.conditional_sigma:
        log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
    else:
        # Expand so that entropy matches batch size. Note that we're using
        # torch.cat here instead of torch.expand() because it is not supported in the
        # verified version of Barracuda (1.0.2).
        log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], axis=0)
    if self.tanh_squash:
        return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
    else:
        return [GaussianDistInstance(mu, torch.exp(log_sigma))]
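
# A small standalone check (illustration only) of the trick used above: repeating a
# (1, n) parameter with torch.cat along dim 0 produces the same tensor as expand(),
# which is the op the comment says the verified Barracuda 1.0.2 could not import.
# The parameter shape and batch size are arbitrary example values.
import torch

log_sigma = torch.zeros(1, 3, requires_grad=True)  # stand-in for the learned parameter
batch_size = 4
via_cat = torch.cat([log_sigma] * batch_size, dim=0)
via_expand = log_sigma.expand(batch_size, -1)
assert torch.equal(via_cat, via_expand)  # same values, in a Barracuda-friendly form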