Code Example #1
File: attention.py Project: zereyak13/ml-agents
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_mask: Optional[torch.Tensor] = None,
        number_of_keys: int = -1,
        number_of_queries: int = -1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # the batch size
        # Use the sizes passed in when available instead of .size(), since
        # Barracuda does not support dynamic tensor shapes.
        n_q = number_of_queries if number_of_queries != -1 else query.size(1)
        n_k = number_of_keys if number_of_keys != -1 else key.size(1)

        query = self.fc_q(query)  # (b, n_q, h*d)
        key = self.fc_k(key)  # (b, n_k, h*d)
        value = self.fc_v(value)  # (b, n_k, h*d)

        query = query.reshape(b, n_q, self.n_heads, self.embedding_size)
        key = key.reshape(b, n_k, self.n_heads, self.embedding_size)
        value = value.reshape(b, n_k, self.n_heads, self.embedding_size)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb)
        # The next few lines are equivalent to: key.permute([0, 2, 3, 1])
        # This is a hack: ONNX will fuse two consecutive permute operations
        # into one, and Barracuda will not accept `permute([0, 2, 3, 1])`.
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        # The no-op arithmetic below keeps ONNX from fusing the two permutes.
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        if key_mask is None:
            qk = qk / (self.embedding_size**0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = ((1 - key_mask) * qk / (self.embedding_size ** 0.5)
                  + key_mask * self.NEG_INF)

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb)

        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb)
        value_attention = value_attention.reshape(
            b, n_q, self.n_heads * self.embedding_size)  # (b, n_q, h*emb)

        out = self.fc_out(value_attention)  # (b, n_q, emb)
        return out, att
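
The shape bookkeeping above is easier to follow with concrete numbers. The snippet below is a minimal, standalone sketch of the same computation (using the direct permute rather than the Barracuda workaround, since plain PyTorch accepts it); the batch size, head count, and per-head embedding size are made-up values, not taken from the project.

import torch

b, n_q, n_k, h, emb = 2, 5, 7, 4, 8  # hypothetical sizes

query = torch.randn(b, n_q, h * emb)  # stands in for fc_q's output
key = torch.randn(b, n_k, h * emb)    # stands in for fc_k's output
value = torch.randn(b, n_k, h * emb)  # stands in for fc_v's output

query = query.reshape(b, n_q, h, emb).permute([0, 2, 1, 3])  # (b, h, n_q, emb)
key = key.reshape(b, n_k, h, emb).permute([0, 2, 3, 1])      # (b, h, emb, n_k)
value = value.reshape(b, n_k, h, emb).permute([0, 2, 1, 3])  # (b, h, n_k, emb)

qk = torch.matmul(query, key) / emb ** 0.5  # (b, h, n_q, n_k)
att = torch.softmax(qk, dim=3)
out = torch.matmul(att, value)              # (b, h, n_q, emb)
out = out.permute([0, 2, 1, 3]).reshape(b, n_q, h * emb)
print(out.shape)  # torch.Size([2, 5, 32])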
Code Example #2
    def forward(self, x_self: torch.Tensor,
                entities: torch.Tensor) -> torch.Tensor:
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that "
                    "doesn't have a set max number of elements.")
            num_entities = entities.shape[1]

        if exporting_to_onnx.is_exporting():
            # When exporting to ONNX, we want to transpose the entities
            # because ONNX only supports input in NCHW (channel-first)
            # format. Barracuda also expects data in NCHW.
            entities = torch.transpose(entities, 2, 1).reshape(
                -1, num_entities, self.entity_size)

        if self.self_size > 0:
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            # Concatenate all observations with self
            entities = torch.cat([expanded_self, entities], dim=2)
        # Encode entities
        encoded_entities = self.self_ent_encoder(entities)
        return encoded_entities
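
The self-concatenation step is the only non-obvious part here: the self observation is tiled once per entity along dim 1, then joined to each entity's features along dim 2. A toy sketch with made-up sizes:

import torch

batch, num_entities, entity_size, self_size = 2, 3, 4, 6  # hypothetical

x_self = torch.randn(batch, self_size)
entities = torch.randn(batch, num_entities, entity_size)

expanded_self = x_self.reshape(-1, 1, self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)  # (2, 3, 6)
combined = torch.cat([expanded_self, entities], dim=2)
print(combined.shape)  # torch.Size([2, 3, 10])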
Code Example #3
 def forward(self, x_self: torch.Tensor,
             entities: List[torch.Tensor]) -> torch.Tensor:
     if self.concat_self:
         # Concatenate all observations with self
         self_and_ent: List[torch.Tensor] = []
         for num_entities, ent in zip(self.entity_num_max_elements,
                                      entities):
             if num_entities < 0:
                 if exporting_to_onnx.is_exporting():
                     raise UnityTrainerException(
                         "Trying to export an attention mechanism that "
                         "doesn't have a set max number of elements.")
                 num_entities = ent.shape[1]
             expanded_self = x_self.reshape(-1, 1, self.self_size)
             expanded_self = torch.cat([expanded_self] * num_entities,
                                       dim=1)
             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     else:
         self_and_ent = entities
     # Encode and concatenate entities
     encoded_entities = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     encoded_entities = self.embedding_norm(encoded_entities)
     return encoded_entities
Code Example #4
    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        n_q: int,
        n_k: int,
        key_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        b = -1  # the batch size

        query = query.reshape(b, n_q, self.n_heads,
                              self.head_size)  # (b, n_q, h, emb / h)
        key = key.reshape(b, n_k, self.n_heads,
                          self.head_size)  # (b, n_k, h, emb / h)
        value = value.reshape(b, n_k, self.n_heads,
                              self.head_size)  # (b, n_k, h, emb / h)

        query = query.permute([0, 2, 1, 3])  # (b, h, n_q, emb / h)
        # The next few lines are equivalent to: key.permute([0, 2, 3, 1])
        # This is a hack: ONNX will fuse two consecutive permute operations
        # into one, and Barracuda will not accept `permute([0, 2, 3, 1])`.
        key = key.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        # The no-op arithmetic below keeps ONNX from fusing the two permutes.
        key -= 1
        key += 1
        key = key.permute([0, 1, 3, 2])  # (b, h, emb / h, n_k)

        qk = torch.matmul(query, key)  # (b, h, n_q, n_k)

        if key_mask is None:
            qk = qk / (self.embedding_size**0.5)
        else:
            key_mask = key_mask.reshape(b, 1, 1, n_k)
            qk = ((1 - key_mask) * qk / (self.embedding_size ** 0.5)
                  + key_mask * self.NEG_INF)

        att = torch.softmax(qk, dim=3)  # (b, h, n_q, n_k)

        value = value.permute([0, 2, 1, 3])  # (b, h, n_k, emb / h)
        value_attention = torch.matmul(att, value)  # (b, h, n_q, emb / h)

        value_attention = value_attention.permute([0, 2, 1, 3])  # (b, n_q, h, emb / h)
        value_attention = value_attention.reshape(
            b, n_q, self.embedding_size)  # (b, n_q, emb)

        return value_attention, att
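
The permute workaround used in the forwards above can be checked in isolation. This sketch (plain PyTorch, arbitrary shape) confirms that permute([0, 2, 1, 3]) followed by permute([0, 1, 3, 2]) is the same transposition as the single permute([0, 2, 3, 1]) that Barracuda rejects:

import torch

key = torch.randn(2, 3, 4, 5)  # arbitrary (b, n_k, h, emb / h) shape

direct = key.permute([0, 2, 3, 1])
two_step = key.permute([0, 2, 1, 3]).permute([0, 1, 3, 2])
assert torch.equal(direct, two_step)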
Code Example #5
 def forward(self, x_self: torch.Tensor,
             entities: torch.Tensor) -> torch.Tensor:
     if self.self_size > 0:
         num_entities = self.entity_num_max_elements
         if num_entities < 0:
             if exporting_to_onnx.is_exporting():
                 raise UnityTrainerException(
                     "Trying to export an attention mechanism that "
                     "doesn't have a set max number of elements.")
             num_entities = entities.shape[1]
         expanded_self = x_self.reshape(-1, 1, self.self_size)
         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         # Concatenate all observations with self
         entities = torch.cat([expanded_self, entities], dim=2)
     # Encode entities
     encoded_entities = self.self_ent_encoder(entities)
     return encoded_entities
Code Example #6
File: attention.py Project: zereyak13/ml-agents
 def forward(
     self,
     x_self: torch.Tensor,
     entities: List[torch.Tensor],
     key_masks: List[torch.Tensor],
 ) -> torch.Tensor:
     # Gather the maximum number of entities information
     if self.entities_num_max_elements is None:
         self.entities_num_max_elements = []
         for ent in entities:
             self.entities_num_max_elements.append(ent.shape[1])
     # Concatenate all observations with self
     self_and_ent: List[torch.Tensor] = []
     for num_entities, ent in zip(self.entities_num_max_elements, entities):
         expanded_self = x_self.reshape(-1, 1, self.self_size)
         # A torch.cat of copies is used instead of the equivalent
         # .repeat(1, num_entities, 1) for ONNX/Barracuda compatibility.
         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     # Generate the tensor that will serve as query, key and value to self attention
     qkv = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     mask = torch.cat(key_masks, dim=1)
     # Feed to self attention
     max_num_ent = sum(self.entities_num_max_elements)
     output, _ = self.attention(qkv, qkv, qkv, mask, max_num_ent, max_num_ent)
     # Residual
     output = self.residual_layer(output) + qkv
     # Average Pooling
     numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1),
                           dim=1)
     denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPISLON
     output = numerator / denominator
     # Residual between x_self and the output of the module
     output = self.x_self_residual_layer(torch.cat([output, x_self], dim=1))
     return output
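
The masked average pooling near the end is worth isolating: entities whose mask is 1 are zeroed out of the numerator and excluded from the count in the denominator, so padding never dilutes the mean. A toy example with made-up values (EPSILON stands in for the class constant):

import torch

EPSILON = 1e-7
max_num_ent = 3
output = torch.tensor([[[1.0], [2.0], [100.0]]])  # (b, max_num_ent, 1)
mask = torch.tensor([[0.0, 0.0, 1.0]])            # last entity is padding

numerator = torch.sum(output * (1 - mask).reshape(-1, max_num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
print(numerator / denominator)  # tensor([[1.5000]]); the padded 100.0 is ignored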
Code Example #7
 def forward(self, x_self: torch.Tensor,
             entities: List[torch.Tensor]) -> torch.Tensor:
     if self.concat_self:
         # Concatenate all observations with self
         self_and_ent: List[torch.Tensor] = []
         for num_entities, ent in zip(self.entity_num_max_elements,
                                      entities):
             expanded_self = x_self.reshape(-1, 1, self.self_size)
             expanded_self = torch.cat([expanded_self] * num_entities,
                                       dim=1)
             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     else:
         self_and_ent = entities
     # Encode and concatenate entities
     encoded_entities = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     return encoded_entities
Code Example #8
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         # At inference time the visual observation arrives as NHWC; convert
         # it to NCHW here. During ONNX export the permute is skipped because
         # Barracuda already feeds the data channels-first.
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = visual_obs.reshape(-1, self.input_size)
     return self.dense(hidden)
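
A quick sketch of what that permute does, with a hypothetical 84x84 RGB observation: at inference time the network receives NHWC tensors, and the permute converts them to NCHW before flattening for the dense layer.

import torch

visual_obs = torch.randn(1, 84, 84, 3)   # (N, H, W, C), hypothetical size
nchw = visual_obs.permute([0, 3, 1, 2])  # (N, C, H, W)
flat = nchw.reshape(-1, 3 * 84 * 84)     # what self.input_size would flatten to
print(flat.shape)  # torch.Size([1, 21168])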