Example #1
    def forward(self, x_self: torch.Tensor,
                entities: torch.Tensor) -> torch.Tensor:
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max \
                    number of elements.")
            num_entities = entities.shape[1]

        if exporting_to_onnx.is_exporting():
            # When exporting to ONNX, we want to transpose the entities. This is
            # because ONNX only supports input in NCHW (channel first) format.
            # Barracuda also expects data in NCHW.
            entities = torch.transpose(entities, 2,
                                       1).reshape(-1, num_entities,
                                                  self.entity_size)

        if self.self_size > 0:
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            # Concatenate all observations with self
            entities = torch.cat([expanded_self, entities], dim=2)
        # Encode entities
        encoded_entities = self.self_ent_encoder(entities)
        return encoded_entities
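A minimal standalone sketch of the tiling pattern used above, with made-up sizes and plain PyTorch (no ML-Agents classes): the self observation is reshaped to (batch, 1, self_size), repeated once per entity via torch.cat, and concatenated onto every entity along the feature dimension.

import torch

batch, num_entities, self_size, entity_size = 4, 3, 8, 6
x_self = torch.rand(batch, self_size)
entities = torch.rand(batch, num_entities, entity_size)

# (batch, 1, self_size) repeated along dim 1 so every entity sees the same self vector
expanded_self = x_self.reshape(-1, 1, self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)

# Concatenate along the feature dimension: (batch, num_entities, self_size + entity_size)
self_and_entities = torch.cat([expanded_self, entities], dim=2)
assert self_and_entities.shape == (batch, num_entities, self_size + entity_size)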
Example #2
 def get_action_out(self, inputs: torch.Tensor,
                    masks: torch.Tensor) -> Tuple[torch.Tensor, ...]:
     """
     Gets the tensors corresponding to the output of the policy network to be used for
     inference. Called by the Actor's forward call.
      :param inputs: The encoding from the network body
      :param masks: Action masks for discrete actions
     :return: A tuple of torch tensors corresponding to the inference output
     """
     dists = self._get_dists(inputs, masks)
     continuous_out, discrete_out, action_out_deprecated = None, None, None
     if self.action_spec.continuous_size > 0 and dists.continuous is not None:
         continuous_out = dists.continuous.exported_model_output()
         action_out_deprecated = dists.continuous.exported_model_output()
         if self._clip_action_on_export:
             continuous_out = torch.clamp(continuous_out, -3, 3) / 3
             action_out_deprecated = torch.clamp(action_out_deprecated, -3, 3) / 3
     if self.action_spec.discrete_size > 0 and dists.discrete is not None:
         discrete_out_list = [
             discrete_dist.exported_model_output()
             for discrete_dist in dists.discrete
         ]
         discrete_out = torch.cat(discrete_out_list, dim=1)
         action_out_deprecated = torch.cat(discrete_out_list, dim=1)
     # deprecated action field does not support hybrid action
     if self.action_spec.continuous_size > 0 and self.action_spec.discrete_size > 0:
         action_out_deprecated = None
     return continuous_out, discrete_out, action_out_deprecated
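The export-time clipping above maps unbounded continuous outputs into [-1, 1]; a small sketch of just that transform on illustrative values:

import torch

continuous_out = torch.tensor([[-5.0, -1.5, 0.0, 2.0, 7.0]])
# Clamp to [-3, 3], then rescale so the exported model emits actions in [-1, 1]
clipped = torch.clamp(continuous_out, -3, 3) / 3
print(clipped)  # approximately [[-1.0, -0.5, 0.0, 0.6667, 1.0]]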
Example #3
 def forward(self, x_self: torch.Tensor,
             entities: List[torch.Tensor]) -> torch.Tensor:
     if self.concat_self:
         # Concatenate all observations with self
         self_and_ent: List[torch.Tensor] = []
         for num_entities, ent in zip(self.entity_num_max_elements,
                                      entities):
             if num_entities < 0:
                 if exporting_to_onnx.is_exporting():
                     raise UnityTrainerException(
                         "Trying to export an attention mechanism that doesn't have a set max \
                         number of elements.")
                 num_entities = ent.shape[1]
             expanded_self = x_self.reshape(-1, 1, self.self_size)
             expanded_self = torch.cat([expanded_self] * num_entities,
                                       dim=1)
             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     else:
         self_and_ent = entities
      # Encode and concatenate entities
     encoded_entities = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     encoded_entities = self.embedding_norm(encoded_entities)
     return encoded_entities
Example #4
    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encodes = []
        for idx, processor in enumerate(self.processors):
            obs_input = inputs[idx]
            processed_obs = processor(obs_input)
            encodes.append(processed_obs)

        if len(encodes) == 0:
            raise Exception("No valid inputs to network.")

        # Constants don't work in Barracuda
        if actions is not None:
            inputs = torch.cat(encodes + [actions], dim=-1)
        else:
            inputs = torch.cat(encodes, dim=-1)
        encoding = self.linear_encoder(inputs)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
Example #5
    def forward(
        self,
        obs_only: List[List[torch.Tensor]],
        obs: List[List[torch.Tensor]],
        actions: List[AgentAction],
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Returns the encoding of the given observations and actions.
        If memory is enabled, return the memories as well.
        :param obs_only: Observations to be processed that do not have corresponding actions.
            These are encoded with the obs_encoder.
        :param obs: Observations to be processed that do have corresponding actions.
            After concatenation with actions, these are processed with obs_action_encoder.
        :param actions: After concatenation with obs, these are processed with obs_action_encoder.
        :param memories: If using memory, a Tensor of initial memories.
        :param sequence_length: If using memory, the sequence length.
        """
        self_attn_masks = []
        self_attn_inputs = []
        concat_f_inp = []
        if obs:
            obs_attn_mask = self._get_masks_from_nans(obs)
            obs = self._copy_and_remove_nans_from_obs(obs, obs_attn_mask)
            for inputs, action in zip(obs, actions):
                encoded = self.observation_encoder(inputs)
                cat_encodes = [
                    encoded,
                    action.to_flat(self.action_spec.discrete_branches),
                ]
                concat_f_inp.append(torch.cat(cat_encodes, dim=1))
            f_inp = torch.stack(concat_f_inp, dim=1)
            self_attn_masks.append(obs_attn_mask)
            self_attn_inputs.append(self.obs_action_encoder(None, f_inp))

        concat_encoded_obs = []
        if obs_only:
            obs_only_attn_mask = self._get_masks_from_nans(obs_only)
            obs_only = self._copy_and_remove_nans_from_obs(
                obs_only, obs_only_attn_mask)
            for inputs in obs_only:
                encoded = self.observation_encoder(inputs)
                concat_encoded_obs.append(encoded)
            g_inp = torch.stack(concat_encoded_obs, dim=1)
            self_attn_masks.append(obs_only_attn_mask)
            self_attn_inputs.append(self.obs_encoder(None, g_inp))

        encoded_entity = torch.cat(self_attn_inputs, dim=1)
        encoded_state = self.self_attn(encoded_entity, self_attn_masks)

        encoding = self.linear_encoder(encoded_state)
        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
Example #6
    def forward(
        self,
        inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encodes = []
        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

        for idx, processor in enumerate(self.processors):
            if not isinstance(processor, EntityEmbedding):
                # The input can be encoded without having to process other inputs
                obs_input = inputs[idx]
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            else:
                var_len_processor_inputs.append((processor, inputs[idx]))
        if len(encodes) != 0:
            encoded_self = torch.cat(encodes, dim=1)
            input_exist = True
        else:
            input_exist = False
        if len(var_len_processor_inputs) > 0:
            # Some inputs need to be processed with a variable length encoder
            masks = get_zero_entities_mask(
                [p_i[1] for p_i in var_len_processor_inputs])
            embeddings: List[torch.Tensor] = []
            processed_self = self.x_self_encoder(
                encoded_self) if input_exist else None
            for processor, var_len_input in var_len_processor_inputs:
                embeddings.append(processor(processed_self, var_len_input))
            qkv = torch.cat(embeddings, dim=1)
            attention_embedding = self.rsa(qkv, masks)
            if not input_exist:
                encoded_self = torch.cat([attention_embedding], dim=1)
                input_exist = True
            else:
                encoded_self = torch.cat([encoded_self, attention_embedding],
                                         dim=1)

        if not input_exist:
            raise Exception(
                "The trainer was unable to process any of the provided inputs. "
                "Make sure the trained agents has at least one sensor attached to them."
            )

        if actions is not None:
            encoded_self = torch.cat([encoded_self, actions], dim=1)
        encoding = self.linear_encoder(encoded_self)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
Example #7
 def generate_input_helper(pattern):
     _input = torch.zeros((batch_size, 0, size))
     for i in range(len(pattern)):
         if i % 2 == 0:
             _input = torch.cat(
                 [_input,
                  torch.rand((batch_size, pattern[i], size))],
                 dim=1)
         else:
             _input = torch.cat(
                 [_input,
                  torch.zeros((batch_size, pattern[i], size))],
                 dim=1)
     return _input
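generate_input_helper interleaves real (random) and padded (all-zero) entity blocks. A hedged sketch of how such padding can be detected afterwards; this re-implements the masking idea locally instead of calling ML-Agents' get_zero_entities_mask:

import torch

batch_size, size = 2, 4
real = torch.rand(batch_size, 3, size)       # 3 real entities
padding = torch.zeros(batch_size, 2, size)   # 2 padded (all-zero) entities
_input = torch.cat([real, padding], dim=1)   # (batch, 5, size)

# An entity is treated as absent when its feature vector is exactly zero
mask = (_input.abs().sum(dim=2) == 0).float()
print(mask)  # [[0., 0., 0., 1., 1.], [0., 0., 0., 1., 1.]] (with overwhelming probability)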
Example #8
    def forward(self, inputs: List[torch.Tensor]) -> torch.Tensor:
        """
        Encode observations using a list of processors and an RSA.
        :param inputs: List of Tensors corresponding to a set of obs.
        :param processors: a ModuleList of the input processors to be applied to these obs.
        :param rsa: Optionally, an RSA to use for variable length obs.
        :param x_self_encoder: Optionally, an encoder to use for x_self (in this case, the non-variable inputs.).
        """
        encodes = []
        var_len_processor_inputs: List[Tuple[nn.Module, torch.Tensor]] = []

        for idx, processor in enumerate(self.processors):
            if not isinstance(processor, EntityEmbedding):
                # The input can be encoded without having to process other inputs
                obs_input = inputs[idx]
                processed_obs = processor(obs_input)
                encodes.append(processed_obs)
            else:
                var_len_processor_inputs.append((processor, inputs[idx]))
        if len(encodes) != 0:
            encoded_self = torch.cat(encodes, dim=1)
            input_exist = True
        else:
            input_exist = False
        if len(var_len_processor_inputs) > 0 and self.rsa is not None:
            # Some inputs need to be processed with a variable length encoder
            masks = get_zero_entities_mask(
                [p_i[1] for p_i in var_len_processor_inputs])
            embeddings: List[torch.Tensor] = []
            processed_self = (self.x_self_encoder(encoded_self) if input_exist
                              and self.x_self_encoder is not None else None)
            for processor, var_len_input in var_len_processor_inputs:
                embeddings.append(processor(processed_self, var_len_input))
            qkv = torch.cat(embeddings, dim=1)
            attention_embedding = self.rsa(qkv, masks)
            if not input_exist:
                encoded_self = torch.cat([attention_embedding], dim=1)
                input_exist = True
            else:
                encoded_self = torch.cat([encoded_self, attention_embedding],
                                         dim=1)

        if not input_exist:
            raise UnityTrainerException(
                "The trainer was unable to process any of the provided inputs. "
                "Make sure the trained agents has at least one sensor attached to them."
            )

        return encoded_self
Example #9
 def to_flat(self, discrete_branches: List[int]) -> torch.Tensor:
     """
     Flatten this AgentAction into a single torch Tensor of dimension (batch, num_continuous + num_one_hot_discrete).
     Discrete actions are converted into one-hot and concatenated with continuous actions.
     :param discrete_branches: List of sizes for discrete actions.
     :return: Tensor of flattened actions.
     """
     # if there are any discrete actions, create one-hot
     if self.discrete_list is not None and self.discrete_list:
         discrete_oh = ModelUtils.actions_to_onehot(self.discrete_tensor,
                                                    discrete_branches)
         discrete_oh = torch.cat(discrete_oh, dim=1)
     else:
         discrete_oh = torch.empty(0)
     return torch.cat([self.continuous_tensor, discrete_oh], dim=-1)
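A self-contained sketch of the flattening that to_flat performs, using torch.nn.functional.one_hot in place of ModelUtils.actions_to_onehot (branch sizes and batch are made up):

import torch
import torch.nn.functional as F

batch = 3
continuous = torch.rand(batch, 2)                  # 2 continuous actions
discrete = torch.tensor([[0, 2], [1, 0], [2, 1]])  # 2 discrete branches
discrete_branches = [3, 3]                         # sizes of the two branches

# One-hot encode each branch, then concatenate branches along the feature dim
one_hots = [
    F.one_hot(discrete[:, i], num_classes=n).float()
    for i, n in enumerate(discrete_branches)
]
discrete_oh = torch.cat(one_hots, dim=1)           # (batch, 6)

flat = torch.cat([continuous, discrete_oh], dim=-1)  # (batch, 2 + 6)
assert flat.shape == (batch, 8)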
Example #10
 def get_action_stats_and_value(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
     sequence_length: int = 1,
 ) -> Tuple[
     AgentAction, ActionLogProbs, torch.Tensor, Dict[str, torch.Tensor], torch.Tensor
 ]:
     if self.use_lstm:
          # Split the memories in half: the first half for the actor, the second for the critic
         actor_mem, critic_mem = torch.split(memories, self.memory_size // 2, dim=-1)
     else:
         critic_mem = None
         actor_mem = None
     encoding, actor_mem_outs = self.network_body(
         vec_inputs, vis_inputs, memories=actor_mem, sequence_length=sequence_length
     )
     action, log_probs, entropies = self.action_model(encoding, masks)
     value_outputs, critic_mem_outs = self.critic(
         vec_inputs, vis_inputs, memories=critic_mem, sequence_length=sequence_length
     )
     if self.use_lstm:
         mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
     else:
         mem_out = None
     return action, log_probs, entropies, value_outputs, mem_out
Example #11
 def get_goal_encoding(self, inputs: List[torch.Tensor]) -> torch.Tensor:
     """
     Encode observations corresponding to goals using a list of processors.
     :param inputs: List of Tensors corresponding to a set of obs.
     """
     encodes = []
     for idx in self._goal_processor_indices:
         processor = self.processors[idx]
         if not isinstance(processor, EntityEmbedding):
             # The input can be encoded without having to process other inputs
             obs_input = inputs[idx]
             processed_obs = processor(obs_input)
             encodes.append(processed_obs)
         else:
             raise UnityTrainerException(
                 "The one of the goals uses variable length observations. This use "
                 "case is not supported."
             )
     if len(encodes) != 0:
         encoded = torch.cat(encodes, dim=1)
     else:
         raise UnityTrainerException(
             "Trainer was unable to process any of the goals provided as input."
         )
     return encoded
Example #12
    def forward(self, inp: torch.Tensor,
                key_masks: List[torch.Tensor]) -> torch.Tensor:
        # Gather the maximum number of entities information
        mask = torch.cat(key_masks, dim=1)

        inp = self.embedding_norm(inp)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)

        # Only use max num if provided
        if self.max_num_ent is not None:
            num_ent = self.max_num_ent
        else:
            num_ent = inp.shape[1]
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max \
                    number of elements.")

        output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
        # Residual
        output = self.fc_out(output) + inp
        output = self.residual_norm(output)
        # Average Pooling
        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1),
                              dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        return output
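The masked average pooling at the end of this block can be isolated. A mask value of 1 marks a padded entity, so (1 - mask) keeps only real entities in both the sum and the count; a minimal sketch with toy shapes:

import torch

batch, num_ent, emb = 2, 4, 3
output = torch.rand(batch, num_ent, emb)
mask = torch.tensor([[0., 0., 1., 1.],   # last two entities are padding
                     [0., 1., 1., 1.]])
EPSILON = 1e-7

# Zero out padded entities, then divide by the number of real entities
numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
pooled = numerator / denominator         # (batch, emb)
assert pooled.shape == (batch, emb)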
Example #13
 def _get_probs_and_entropy(
         self, actions: AgentAction,
         dists: DistInstances) -> Tuple[ActionLogProbs, torch.Tensor]:
     """
      Computes the log probabilities of the actions given the distributions, and the
      entropies of the given distributions.
     :params actions: The AgentAction
     :params dists: The DistInstances tuple
     :return: An ActionLogProbs tuple and a torch tensor of the distribution entropies.
     """
     entropies_list: List[torch.Tensor] = []
     continuous_log_prob: Optional[torch.Tensor] = None
     discrete_log_probs: Optional[List[torch.Tensor]] = None
     all_discrete_log_probs: Optional[List[torch.Tensor]] = None
     # This checks None because mypy complains otherwise
     if dists.continuous is not None:
         continuous_log_prob = dists.continuous.log_prob(
             actions.continuous_tensor)
         entropies_list.append(dists.continuous.entropy())
     if dists.discrete is not None:
         discrete_log_probs = []
         all_discrete_log_probs = []
         for discrete_action, discrete_dist in zip(
                 actions.discrete_list,
                 dists.discrete  # type: ignore
         ):
             discrete_log_prob = discrete_dist.log_prob(discrete_action)
             entropies_list.append(discrete_dist.entropy())
             discrete_log_probs.append(discrete_log_prob)
             all_discrete_log_probs.append(discrete_dist.all_log_prob())
     action_log_probs = ActionLogProbs(continuous_log_prob,
                                       discrete_log_probs,
                                       all_discrete_log_probs)
     entropies = torch.cat(entropies_list, dim=1)
     return action_log_probs, entropies
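A hedged sketch of the same bookkeeping using plain torch.distributions (Normal for the continuous part, one Categorical per discrete branch) instead of the DistInstances wrappers; shapes and branch sizes are illustrative:

import torch
from torch.distributions import Normal, Categorical

batch = 4
continuous_dist = Normal(torch.zeros(batch, 2), torch.ones(batch, 2))
discrete_dists = [Categorical(logits=torch.rand(batch, 3)),
                  Categorical(logits=torch.rand(batch, 5))]

continuous_action = continuous_dist.sample()
discrete_actions = [d.sample() for d in discrete_dists]

entropies_list = [continuous_dist.entropy()]          # (batch, 2)
continuous_log_prob = continuous_dist.log_prob(continuous_action)
discrete_log_probs = []
for action, dist in zip(discrete_actions, discrete_dists):
    discrete_log_probs.append(dist.log_prob(action))  # (batch,)
    entropies_list.append(dist.entropy().unsqueeze(1))

entropies = torch.cat(entropies_list, dim=1)          # (batch, 2 + 2)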
Example #14
 def get_dist_and_value(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
     sequence_length: int = 1,
 ) -> Tuple[List[DistInstance], Dict[str, torch.Tensor], torch.Tensor]:
     if self.use_lstm:
          # Split the memories in half: the first half for the actor, the second for the critic
         actor_mem, critic_mem = torch.split(memories,
                                             self.memory_size // 2,
                                             dim=-1)
     else:
         critic_mem = None
         actor_mem = None
     dists, actor_mem_outs = self.get_dists(
         vec_inputs,
         vis_inputs,
         memories=actor_mem,
         sequence_length=sequence_length,
         masks=masks,
     )
     value_outputs, critic_mem_outs = self.critic(
         vec_inputs,
         vis_inputs,
         memories=critic_mem,
         sequence_length=sequence_length)
     if self.use_lstm:
         mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
     else:
         mem_out = None
     return dists, value_outputs, mem_out
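The memory handling in get_dist_and_value simply halves the memory tensor and stitches the updated halves back together. A minimal, self-contained sketch of that split/recombine (memory_size is illustrative):

import torch

batch, memory_size = 2, 8
memories = torch.rand(batch, memory_size)

# First half drives the actor, second half the critic
actor_mem, critic_mem = torch.split(memories, memory_size // 2, dim=-1)

# ... run actor and critic, each returning an updated half ...
actor_mem_outs, critic_mem_outs = actor_mem + 1, critic_mem - 1

mem_out = torch.cat([actor_mem_outs, critic_mem_outs], dim=-1)
assert mem_out.shape == memories.shape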
Example #15
 def compute_estimate(
     self, mini_batch: AgentBuffer, use_vail_noise: bool = False
 ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
     """
     Given a mini_batch, computes the estimate (How much the discriminator believes
     the data was sampled from the demonstration data).
     :param mini_batch: The AgentBuffer of data
     :param use_vail_noise: Only when using VAIL : If true, will sample the code, if
     false, will return the mean of the code.
     """
     vec_inputs, vis_inputs = self.get_state_inputs(mini_batch)
     if self._settings.use_actions:
         actions = self.get_action_input(mini_batch)
         dones = torch.as_tensor(mini_batch["done"],
                                 dtype=torch.float).unsqueeze(1)
         action_inputs = torch.cat([actions, dones], dim=1)
         hidden, _ = self.encoder(vec_inputs, vis_inputs, action_inputs)
     else:
         hidden, _ = self.encoder(vec_inputs, vis_inputs)
     z_mu: Optional[torch.Tensor] = None
     if self._settings.use_vail:
         z_mu = self._z_mu_layer(hidden)
         hidden = torch.normal(z_mu, self._z_sigma * use_vail_noise)
     estimate = self._estimator(hidden)
     return estimate, z_mu
Example #16
 def forward(
     self,
     vec_inputs: List[torch.Tensor],
     vis_inputs: List[torch.Tensor],
     masks: Optional[torch.Tensor] = None,
     memories: Optional[torch.Tensor] = None,
 ) -> Tuple[torch.Tensor, int, int, int, int]:
     """
     Note: This forward() method is required for exporting to ONNX. Don't modify the inputs and outputs.
     """
     dists, _ = self.get_dists(vec_inputs, vis_inputs, masks, memories, 1)
     if self.action_spec.is_continuous():
         action_list = self.sample_action(dists)
         action_out = torch.stack(action_list, dim=-1)
         if self._clip_action_on_export:
             action_out = torch.clamp(action_out, -3, 3) / 3
     else:
         action_out = torch.cat([dist.all_log_prob() for dist in dists],
                                dim=1)
     return (
         action_out,
         self.version_number,
         torch.Tensor([self.network_body.memory_size]),
         self.is_continuous_int,
         self.act_size_vector,
     )
Example #17
 def predict_action(self, mini_batch: AgentBuffer) -> torch.Tensor:
     """
     In the continuous case, returns the predicted action.
     In the discrete case, returns the logits.
     """
     inverse_model_input = torch.cat((self.get_current_state(mini_batch),
                                      self.get_next_state(mini_batch)),
                                     dim=1)
     hidden = self.inverse_model_action_prediction(inverse_model_input)
     if self._policy_specs.is_action_continuous():
         return hidden
     else:
         branches = ModelUtils.break_into_branches(
             hidden, self._policy_specs.discrete_action_branches)
         branches = [torch.softmax(b, dim=1) for b in branches]
         return torch.cat(branches, dim=1)
Example #18
 def forward(self, x_self: torch.Tensor,
             entities: torch.Tensor) -> torch.Tensor:
     if self.self_size > 0:
         num_entities = self.entity_num_max_elements
         if num_entities < 0:
             if exporting_to_onnx.is_exporting():
                 raise UnityTrainerException(
                     "Trying to export an attention mechanism that doesn't have a set max \
                     number of elements.")
             num_entities = entities.shape[1]
         expanded_self = x_self.reshape(-1, 1, self.self_size)
         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         # Concatenate all observations with self
         entities = torch.cat([expanded_self, entities], dim=2)
     # Encode entities
     encoded_entities = self.self_ent_encoder(entities)
     return encoded_entities
Example #19
 def forward(self, input_tensor: torch.Tensor,
             goal_tensor: torch.Tensor) -> torch.Tensor:  # type: ignore
     activation = torch.cat([input_tensor, goal_tensor], dim=-1)
     for layer in self.layers:
         if isinstance(layer, HyperNetwork):
             activation = layer(activation, goal_tensor)
         else:
             activation = layer(activation)
     return activation
Example #20
 def forward(self, action: AgentAction) -> torch.Tensor:
     """
      Returns a tensor corresponding to the flattened action.
     :param action: An AgentAction object
     """
     action_list: List[torch.Tensor] = []
     if self._specs.continuous_size > 0:
         action_list.append(action.continuous_tensor)
     if self._specs.discrete_size > 0:
         flat_discrete = torch.cat(
             ModelUtils.actions_to_onehot(
                 torch.as_tensor(action.discrete_tensor, dtype=torch.long),
                 self._specs.discrete_branches,
             ),
             dim=1,
         )
         action_list.append(flat_discrete)
     return torch.cat(action_list, dim=1)
Example #21
 def forward(self, input_tensor: torch.Tensor,
             memories: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
     # We don't use torch.split here since it is not supported by Barracuda
     h0 = memories[:, :, :self.hidden_size]
     c0 = memories[:, :, self.hidden_size:]
     hidden = (h0, c0)
     lstm_out, hidden_out = self.lstm(input_tensor, hidden)
     output_mem = torch.cat(hidden_out, dim=-1)
     return lstm_out, output_mem
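A standalone sketch of the same packing with torch.nn.LSTM: the hidden and cell states live side by side in one memory tensor, are sliced apart before the LSTM call, and are re-concatenated afterwards. Shapes here are assumptions for illustration:

import torch
import torch.nn as nn

batch, seq_len, input_size, hidden_size = 2, 5, 4, 3
lstm = nn.LSTM(input_size, hidden_size, batch_first=True)

input_tensor = torch.rand(batch, seq_len, input_size)
memories = torch.zeros(1, batch, 2 * hidden_size)   # h and c packed along the last dim

# Slice instead of torch.split (the original avoids split for Barracuda compatibility)
h0 = memories[:, :, :hidden_size].contiguous()
c0 = memories[:, :, hidden_size:].contiguous()

lstm_out, hidden_out = lstm(input_tensor, (h0, c0))
output_mem = torch.cat(hidden_out, dim=-1)          # (1, batch, 2 * hidden_size)
assert output_mem.shape == memories.shape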
Example #22
 def forward(
     self,
     x_self: torch.Tensor,
     entities: List[torch.Tensor],
     key_masks: List[torch.Tensor],
 ) -> torch.Tensor:
     # Gather the maximum number of entities information
     if self.entities_num_max_elements is None:
         self.entities_num_max_elements = []
         for ent in entities:
             self.entities_num_max_elements.append(ent.shape[1])
     # Concatenate all observations with self
     self_and_ent: List[torch.Tensor] = []
     for num_entities, ent in zip(self.entities_num_max_elements, entities):
         expanded_self = x_self.reshape(-1, 1, self.self_size)
          # torch.cat is used below where a .repeat(1, num_entities, 1) would be equivalent
         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     # Generate the tensor that will serve as query, key and value to self attention
     qkv = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     mask = torch.cat(key_masks, dim=1)
     # Feed to self attention
     max_num_ent = sum(self.entities_num_max_elements)
     output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent,
                                max_num_ent)
     # Residual
     output = self.residual_layer(output) + qkv
     # Average Pooling
     numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1),
                           dim=1)
     denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
     output = numerator / denominator
     # Residual between x_self and the output of the module
     output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
     return output
Example #23
 def forward(self, action: torch.Tensor) -> torch.Tensor:
     if self._specs.is_action_continuous():
         return action
     else:
         return torch.cat(
             ModelUtils.actions_to_onehot(
                 torch.as_tensor(action, dtype=torch.long),
                 self._specs.discrete_action_branches,
             ),
             dim=1,
         )
Example #24
    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to predict
        the next state embedding.
        """
        actions = AgentAction.from_buffer(mini_batch)
        flattened_action = self._action_flattener.forward(actions)
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), flattened_action), dim=1)

        return self.forward_model_next_state_prediction(forward_model_input)
Example #25
    def predict_action(self, mini_batch: AgentBuffer) -> ActionPredictionTuple:
        """
        In the continuous case, returns the predicted action.
        In the discrete case, returns the logits.
        """
        inverse_model_input = torch.cat((self.get_current_state(mini_batch),
                                         self.get_next_state(mini_batch)),
                                        dim=1)

        continuous_pred = None
        discrete_pred = None
        hidden = self.inverse_model_action_encoding(inverse_model_input)
        if self._action_spec.continuous_size > 0:
            continuous_pred = self.continuous_action_prediction(hidden)
        if self._action_spec.discrete_size > 0:
            raw_discrete_pred = self.discrete_action_prediction(hidden)
            branches = ModelUtils.break_into_branches(
                raw_discrete_pred, self._action_spec.discrete_branches)
            branches = [torch.softmax(b, dim=1) for b in branches]
            discrete_pred = torch.cat(branches, dim=1)
        return ActionPredictionTuple(continuous_pred, discrete_pred)
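A hedged sketch of the discrete branch handling at the end of predict_action, using torch.split in place of ModelUtils.break_into_branches: the flat logits are cut into per-branch chunks, each branch is softmaxed independently, and the results are concatenated back:

import torch

batch = 2
discrete_branches = [3, 4]                  # two branches of size 3 and 4
raw_discrete_pred = torch.rand(batch, sum(discrete_branches))

branches = torch.split(raw_discrete_pred, discrete_branches, dim=1)
branches = [torch.softmax(b, dim=1) for b in branches]
discrete_pred = torch.cat(branches, dim=1)  # (batch, 7); each branch sums to 1
assert torch.allclose(discrete_pred[:, :3].sum(dim=1), torch.ones(batch))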
Example #26
def test_visual_encoder_trains(vis_class, size):
    torch.manual_seed(0)
    image_size = (size, size, 1)
    batch = 100

    inputs = torch.cat([
        torch.zeros((batch, ) + image_size),
        torch.ones((batch, ) + image_size)
    ],
                       dim=0)
    target = torch.cat([torch.zeros((batch, )), torch.ones((batch, ))], dim=0)
    enc = vis_class(image_size[0], image_size[1], image_size[2], 1)
    optimizer = torch.optim.Adam(enc.parameters(), lr=0.001)

    for _ in range(15):
        prediction = enc(inputs)[:, 0]
        loss = torch.mean((target - prediction)**2)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    assert loss.item() < 0.05
Example #27
    def predict_next_state(self, mini_batch: AgentBuffer) -> torch.Tensor:
        """
        Uses the current state embedding and the action of the mini_batch to predict
        the next state embedding.
        """
        if self._policy_specs.is_action_continuous():
            action = ModelUtils.list_to_tensor(mini_batch["actions"],
                                               dtype=torch.float)
        else:
            action = torch.cat(
                ModelUtils.actions_to_onehot(
                    ModelUtils.list_to_tensor(mini_batch["actions"],
                                              dtype=torch.long),
                    self._policy_specs.discrete_action_branches,
                ),
                dim=1,
            )
        forward_model_input = torch.cat(
            (self.get_current_state(mini_batch), action), dim=1)

        return self.forward_model_next_state_prediction(forward_model_input)
Example #28
 def compute_gradient_magnitude(self, policy_batch: AgentBuffer,
                                expert_batch: AgentBuffer) -> torch.Tensor:
     """
      Gradient penalty from https://arxiv.org/pdf/1704.00028. Adds stability,
      especially for off-policy. Computes gradients w.r.t. randomly interpolated input.
     """
     policy_inputs = self.get_state_inputs(policy_batch)
     expert_inputs = self.get_state_inputs(expert_batch)
     interp_inputs = []
     for policy_input, expert_input in zip(policy_inputs, expert_inputs):
         obs_epsilon = torch.rand(policy_input.shape)
         interp_input = obs_epsilon * policy_input + (
             1 - obs_epsilon) * expert_input
         interp_input.requires_grad = True  # For gradient calculation
         interp_inputs.append(interp_input)
     if self._settings.use_actions:
         policy_action = self.get_action_input(policy_batch)
         expert_action = self.get_action_input(expert_batch)
         action_epsilon = torch.rand(policy_action.shape)
         policy_dones = torch.as_tensor(policy_batch[BufferKey.DONE],
                                        dtype=torch.float).unsqueeze(1)
         expert_dones = torch.as_tensor(expert_batch[BufferKey.DONE],
                                        dtype=torch.float).unsqueeze(1)
         dones_epsilon = torch.rand(policy_dones.shape)
         action_inputs = torch.cat(
             [
                 action_epsilon * policy_action +
                 (1 - action_epsilon) * expert_action,
                 dones_epsilon * policy_dones +
                 (1 - dones_epsilon) * expert_dones,
             ],
             dim=1,
         )
         action_inputs.requires_grad = True
         hidden, _ = self.encoder(interp_inputs, action_inputs)
         encoder_input = tuple(interp_inputs + [action_inputs])
     else:
         hidden, _ = self.encoder(interp_inputs)
         encoder_input = tuple(interp_inputs)
     if self._settings.use_vail:
         use_vail_noise = True
         z_mu = self._z_mu_layer(hidden)
         hidden = z_mu + torch.randn_like(
             z_mu) * self._z_sigma * use_vail_noise
     estimate = self._estimator(hidden).squeeze(1).sum()
     gradient = torch.autograd.grad(estimate,
                                    encoder_input,
                                    create_graph=True)[0]
     # Norm's gradient could be NaN at 0. Use our own safe_norm
     safe_norm = (torch.sum(gradient**2, dim=1) + self.EPSILON).sqrt()
     gradient_mag = torch.mean((safe_norm - 1)**2)
     return gradient_mag
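A compact, self-contained sketch of the gradient-penalty idea (https://arxiv.org/pdf/1704.00028), with a toy linear discriminator standing in for the encoder/estimator pair; everything below is illustrative, not the trainer's actual modules:

import torch
import torch.nn as nn

EPSILON = 1e-7
discriminator = nn.Sequential(nn.Linear(6, 16), nn.ReLU(), nn.Linear(16, 1))

policy_obs = torch.rand(8, 6)
expert_obs = torch.rand(8, 6)

# Random interpolation between policy and expert samples
obs_epsilon = torch.rand(policy_obs.shape)
interp = obs_epsilon * policy_obs + (1 - obs_epsilon) * expert_obs
interp.requires_grad_(True)

estimate = discriminator(interp).squeeze(1).sum()
gradient = torch.autograd.grad(estimate, interp, create_graph=True)[0]

# Safe norm to avoid NaN gradients at exactly 0, then penalize deviation from norm 1
safe_norm = (torch.sum(gradient ** 2, dim=1) + EPSILON).sqrt()
gradient_mag = torch.mean((safe_norm - 1) ** 2)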
Example #29
 def forward(self, x_self: torch.Tensor,
             entities: List[torch.Tensor]) -> torch.Tensor:
     if self.concat_self:
         # Concatenate all observations with self
         self_and_ent: List[torch.Tensor] = []
         for num_entities, ent in zip(self.entity_num_max_elements,
                                      entities):
             expanded_self = x_self.reshape(-1, 1, self.self_size)
             expanded_self = torch.cat([expanded_self] * num_entities,
                                       dim=1)
             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     else:
         self_and_ent = entities
      # Encode and concatenate entities
     encoded_entities = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     return encoded_entities
Example #30
 def forward(self, inputs: torch.Tensor) -> List[DistInstance]:
     mu = self.mu(inputs)
     if self.conditional_sigma:
         log_sigma = torch.clamp(self.log_sigma(inputs), min=-20, max=2)
     else:
          # Expand so that entropy matches batch size. Note that we're using
          # torch.cat here instead of .expand() because expand is not supported in the
          # verified version of Barracuda (1.0.2).
          log_sigma = torch.cat([self.log_sigma] * inputs.shape[0], dim=0)
     if self.tanh_squash:
         return [TanhGaussianDistInstance(mu, torch.exp(log_sigma))]
     else:
         return [GaussianDistInstance(mu, torch.exp(log_sigma))]
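A small sketch of the workaround in the else branch: tiling a (1, act_size) log_sigma parameter across the batch with torch.cat yields the same values that .expand() would (sizes below are made up):

import torch

batch, act_size = 4, 2
log_sigma = torch.nn.Parameter(torch.zeros(1, act_size))
inputs = torch.rand(batch, 3)

tiled = torch.cat([log_sigma] * inputs.shape[0], dim=0)  # (batch, act_size), copies data
expanded = log_sigma.expand(inputs.shape[0], -1)         # same values, no copy
assert torch.equal(tiled, expanded)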