def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, key_mask: Optional[torch.Tensor] = None, number_of_keys: int = -1, number_of_queries: int = -1, ) -> Tuple[torch.Tensor, torch.Tensor]: b = -1 # the batch size # This is to avoid using .size() when possible as Barracuda does not support n_q = number_of_queries if number_of_queries != -1 else query.size(1) n_k = number_of_keys if number_of_keys != -1 else key.size(1) query = self.fc_q(query) # (b, n_q, h*d) key = self.fc_k(key) # (b, n_k, h*d) value = self.fc_v(value) # (b, n_k, h*d) query = query.reshape(b, n_q, self.n_heads, self.embedding_size) key = key.reshape(b, n_k, self.n_heads, self.embedding_size) value = value.reshape(b, n_k, self.n_heads, self.embedding_size) query = query.permute([0, 2, 1, 3]) # (b, h, n_q, emb) # The next few lines are equivalent to : key.permute([0, 2, 3, 1]) # This is a hack, ONNX will compress two permute operations and # Barracuda will not like seeing `permute([0,2,3,1])` key = key.permute([0, 2, 1, 3]) # (b, h, emb, n_k) key -= 1 key += 1 key = key.permute([0, 1, 3, 2]) # (b, h, emb, n_k) qk = torch.matmul(query, key) # (b, h, n_q, n_k) if key_mask is None: qk = qk / (self.embedding_size**0.5) else: key_mask = key_mask.reshape(b, 1, 1, n_k) qk = (1 - key_mask) * qk / (self.embedding_size** 0.5) + key_mask * self.NEG_INF att = torch.softmax(qk, dim=3) # (b, h, n_q, n_k) value = value.permute([0, 2, 1, 3]) # (b, h, n_k, emb) value_attention = torch.matmul(att, value) # (b, h, n_q, emb) value_attention = value_attention.permute([0, 2, 1, 3]) # (b, n_q, h, emb) value_attention = value_attention.reshape( b, n_q, self.n_heads * self.embedding_size) # (b, n_q, h*emb) out = self.fc_out(value_attention) # (b, n_q, emb) return out, att
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
    """Optionally prepend ``x_self`` to every entity, then encode the entities.

    Args:
        x_self: Tensor reshaped to ``(-1, 1, self.self_size)`` and tiled once
            per entity when ``self.self_size > 0``.
        entities: Entity features. When ``self.entity_num_max_elements`` is
            negative, ``entities.shape[1]`` is used as the entity count. During
            ONNX export the tensor is transposed on axes 1 and 2, then reshaped
            to ``(-1, num_entities, self.entity_size)``.

    Returns:
        The output of ``self.self_ent_encoder``.

    Raises:
        UnityTrainerException: if exporting to ONNX while
            ``self.entity_num_max_elements`` is negative, i.e. no static entity
            count is available.
    """
    num_entities = self.entity_num_max_elements
    if num_entities < 0:
        if exporting_to_onnx.is_exporting():
            # Implicit literal concatenation; a backslash inside the literal
            # used to leak into the message.
            raise UnityTrainerException(
                "Trying to export an attention mechanism that doesn't "
                "have a set max number of elements."
            )
        num_entities = entities.shape[1]

    if exporting_to_onnx.is_exporting():
        # When exporting to ONNX, we want to transpose the entities. This is
        # because ONNX only supports input in NCHW (channel first) format.
        # Barracuda also expects to get data in NCHW.
        entities = torch.transpose(entities, 2, 1).reshape(
            -1, num_entities, self.entity_size
        )

    if self.self_size > 0:
        expanded_self = x_self.reshape(-1, 1, self.self_size)
        # Tile by concatenating copies; presumably preferred over .repeat()
        # for ONNX/Barracuda export — confirm before changing.
        expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
        # Concatenate all observations with self
        entities = torch.cat([expanded_self, entities], dim=2)
    # Encode entities
    encoded_entities = self.self_ent_encoder(entities)
    return encoded_entities
def forward(self, x_self: torch.Tensor, entities: List[torch.Tensor]) -> torch.Tensor:
    """Encode each entity group and concatenate the encodings along dim 1.

    When ``self.concat_self`` is set, ``x_self`` is reshaped to
    ``(-1, 1, self.self_size)``, tiled ``num_entities`` times along dim 1, and
    concatenated in front of each group's features on dim 2 before encoding.

    Args:
        x_self: Tensor holding ``self.self_size`` features per batch entry.
        entities: One tensor per entity group. Each is paired positionally
            with ``self.entity_num_max_elements`` and ``self.ent_encoders``.

    Returns:
        ``self.embedding_norm`` applied to the per-group encodings
        concatenated along dim 1. (The previous ``Tuple[torch.Tensor, int]``
        annotation did not match this single-tensor return.)

    Raises:
        UnityTrainerException: if a group's max element count is negative
            while exporting to ONNX. In that case the count could only come
            from the runtime input shape.
    """
    if self.concat_self:
        # Concatenate all observations with self
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entity_num_max_elements, entities):
            if num_entities < 0:
                if exporting_to_onnx.is_exporting():
                    raise UnityTrainerException(
                        "Trying to export an attention mechanism that doesn't "
                        "have a set max number of elements."
                    )
                # No fixed maximum: use this batch's entity count.
                num_entities = ent.shape[1]
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            # Tile by concatenating copies; presumably preferred over
            # .repeat() for ONNX/Barracuda export — confirm before changing.
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
    else:
        self_and_ent = entities
    # Encode and concatenate entities
    encoded_entities = torch.cat(
        [
            ent_encoder(x)
            for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
        ],
        dim=1,
    )
    encoded_entities = self.embedding_norm(encoded_entities)
    return encoded_entities
def forward( self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor, n_q: int, n_k: int, key_mask: Optional[torch.Tensor] = None, ) -> Tuple[torch.Tensor, torch.Tensor]: b = -1 # the batch size query = query.reshape(b, n_q, self.n_heads, self.head_size) # (b, n_q, h, emb / h) key = key.reshape(b, n_k, self.n_heads, self.head_size) # (b, n_k, h, emb / h) value = value.reshape(b, n_k, self.n_heads, self.head_size) # (b, n_k, h, emb / h) query = query.permute([0, 2, 1, 3]) # (b, h, n_q, emb / h) # The next few lines are equivalent to : key.permute([0, 2, 3, 1]) # This is a hack, ONNX will compress two permute operations and # Barracuda will not like seeing `permute([0,2,3,1])` key = key.permute([0, 2, 1, 3]) # (b, h, emb / h, n_k) key -= 1 key += 1 key = key.permute([0, 1, 3, 2]) # (b, h, emb / h, n_k) qk = torch.matmul(query, key) # (b, h, n_q, n_k) if key_mask is None: qk = qk / (self.embedding_size**0.5) else: key_mask = key_mask.reshape(b, 1, 1, n_k) qk = (1 - key_mask) * qk / (self.embedding_size** 0.5) + key_mask * self.NEG_INF att = torch.softmax(qk, dim=3) # (b, h, n_q, n_k) value = value.permute([0, 2, 1, 3]) # (b, h, n_k, emb / h) value_attention = torch.matmul(att, value) # (b, h, n_q, emb / h) value_attention = value_attention.permute([0, 2, 1, 3]) # (b, n_q, h, emb / h) value_attention = value_attention.reshape( b, n_q, self.embedding_size) # (b, n_q, emb) return value_attention, att
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
    """Optionally prepend ``x_self`` to every entity, then encode the entities.

    Args:
        x_self: Tensor reshaped to ``(-1, 1, self.self_size)`` and tiled once
            per entity when ``self.self_size > 0``. Ignored otherwise.
        entities: Entity features. Concatenated after the tiled ``x_self`` on
            dim 2 when ``self.self_size > 0``.

    Returns:
        The output of ``self.self_ent_encoder``.

    Raises:
        UnityTrainerException: if ``self.self_size > 0`` and
            ``self.entity_num_max_elements`` is negative while exporting to ONNX.
    """
    if self.self_size > 0:
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            if exporting_to_onnx.is_exporting():
                # Implicit literal concatenation; a backslash inside the
                # literal used to leak into the message.
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't "
                    "have a set max number of elements."
                )
            # No fixed maximum: use this batch's entity count.
            num_entities = entities.shape[1]
        expanded_self = x_self.reshape(-1, 1, self.self_size)
        # Tile by concatenating copies; presumably preferred over .repeat()
        # for ONNX/Barracuda export — confirm before changing.
        expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
        # Concatenate all observations with self
        entities = torch.cat([expanded_self, entities], dim=2)
    # Encode entities
    encoded_entities = self.self_ent_encoder(entities)
    return encoded_entities
def forward(
    self,
    x_self: torch.Tensor,
    entities: List[torch.Tensor],
    key_masks: List[torch.Tensor],
) -> torch.Tensor:
    """Self-attend over all entities (each prefixed with ``x_self``) and pool.

    On the first call, if ``self.entities_num_max_elements`` is None, it is
    filled with ``ent.shape[1]`` for every tensor in ``entities``. That list
    is then reused on later calls.

    Args:
        x_self: Reshaped to ``(-1, 1, self.self_size)`` and tiled once per
            entity of every group.
        entities: One tensor per entity group. Each is concatenated after the
            tiled ``x_self`` on dim 2 and encoded by the matching
            ``self.ent_encoders`` entry.
        key_masks: One mask per group, concatenated on dim 1. Each position
            is weighted by ``1 - mask`` in the average pooling.

    Returns:
        ``self.x_self_residual_layer`` applied to the pooled output
        concatenated with ``x_self`` on dim 1.
    """
    # Lazily record each group's entity count from the first batch seen.
    if self.entities_num_max_elements is None:
        self.entities_num_max_elements = [ent.shape[1] for ent in entities]
    group_sizes = self.entities_num_max_elements

    # Put a copy of x_self next to every entity of every group. Copies are
    # concatenated rather than .repeat()-ed, presumably for ONNX/Barracuda
    # compatibility — confirm before changing.
    paired: List[torch.Tensor] = []
    for count, group in zip(group_sizes, entities):
        self_row = x_self.reshape(-1, 1, self.self_size)
        tiled_self = torch.cat([self_row] * count, dim=1)
        paired.append(torch.cat([tiled_self, group], dim=2))

    # The encoded entities serve as query, key and value for self-attention.
    encoded = torch.cat(
        [encoder(pair) for encoder, pair in zip(self.ent_encoders, paired)],
        dim=1,
    )
    mask = torch.cat(key_masks, dim=1)
    total = sum(group_sizes)

    attended, _attn_weights = self.attention(
        encoded, encoded, encoded, mask, total, total
    )
    attended = self.residual_layer(attended) + encoded

    # Masked mean over the entity axis. EPISLON keeps the denominator non-zero
    # when every weight is 0 (e.g. all positions masked, for 0/1 masks).
    keep = 1 - mask
    pooled = torch.sum(attended * keep.reshape(-1, total, 1), dim=1) / (
        torch.sum(keep, dim=1, keepdim=True) + self.EPISLON
    )

    return self.x_self_residual_layer(torch.cat([pooled, x_self], dim=1))
def forward(self, x_self: torch.Tensor, entities: List[torch.Tensor]) -> torch.Tensor:
    """Encode each entity group and concatenate the encodings along dim 1.

    When ``self.concat_self`` is set, ``x_self`` is reshaped to
    ``(-1, 1, self.self_size)``, tiled ``num_entities`` times along dim 1 (one
    count per group from ``self.entity_num_max_elements``), and concatenated
    in front of each group's features on dim 2 before encoding.

    Args:
        x_self: Tensor holding ``self.self_size`` features per batch entry.
        entities: One tensor per entity group. Each is paired positionally
            with ``self.entity_num_max_elements`` and ``self.ent_encoders``.

    Returns:
        The per-group encodings concatenated along dim 1. (The previous
        ``Tuple[torch.Tensor, int]`` annotation did not match this
        single-tensor return.)
    """
    if self.concat_self:
        # Concatenate all observations with self
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entity_num_max_elements, entities):
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            # Tile by concatenating copies; presumably preferred over
            # .repeat() for ONNX/Barracuda export — confirm before changing.
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
    else:
        self_and_ent = entities
    # Encode and concatenate entities
    encoded_entities = torch.cat(
        [
            ent_encoder(x)
            for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
        ],
        dim=1,
    )
    return encoded_entities
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
    """Flatten a visual observation and pass it through ``self.dense``.

    Outside ONNX export, the input's axis 3 is moved to position 1 via
    ``permute([0, 3, 1, 2])``. This is presumably channels-last to
    channels-first; confirm the input layout against the caller. During
    export the tensor is used as given. Either way it is then flattened to
    ``(-1, self.input_size)``.

    Args:
        visual_obs: A rank-4 visual observation tensor.

    Returns:
        The output of ``self.dense`` on the flattened observation.
    """
    if exporting_to_onnx.is_exporting():
        reordered = visual_obs
    else:
        reordered = visual_obs.permute([0, 3, 1, 2])
    flattened = reordered.reshape(-1, self.input_size)
    return self.dense(flattened)