def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
    """
    Encode entities, optionally concatenated with the self observation.

    :param x_self: Self observation; it is reshaped to (-1, 1, self.self_size)
        and only used when ``self.self_size > 0``.
    :param entities: Entity tensor. Dimension 1 is read as the number of
        entities when no maximum is configured, and the self observation is
        concatenated on dimension 2.
    :return: The output of ``self.self_ent_encoder`` applied to the
        (possibly self-augmented) entities.
    :raises UnityTrainerException: When exporting to ONNX while
        ``self.entity_num_max_elements`` is negative (no fixed maximum set).
    """
    num_entities = self.entity_num_max_elements
    if num_entities < 0:
        # A negative value means "no fixed maximum": fall back to the runtime
        # shape, which is not allowed while exporting.
        if exporting_to_onnx.is_exporting():
            # Implicit literal concatenation keeps the message on one logical
            # line without embedding a backslash or indentation in the text.
            raise UnityTrainerException(
                "Trying to export an attention mechanism that doesn't have a "
                "set max number of elements."
            )
        num_entities = entities.shape[1]

    if exporting_to_onnx.is_exporting():
        # When exporting to ONNX, we want to transpose the entities. This is
        # because ONNX only supports input in NCHW (channel first) format.
        # Barracuda also expects to get data in NCHW.
        entities = torch.transpose(entities, 2, 1).reshape(
            -1, num_entities, self.entity_size
        )

    if self.self_size > 0:
        # Repeat the self observation once per entity so it can be paired
        # with every entity along dimension 1.
        expanded_self = x_self.reshape(-1, 1, self.self_size)
        expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
        # Concatenate all observations with self
        entities = torch.cat([expanded_self, entities], dim=2)

    # Encode entities
    encoded_entities = self.self_ent_encoder(entities)
    return encoded_entities
def forward(self, inp: torch.Tensor, key_masks: List[torch.Tensor]) -> torch.Tensor:
    """
    Apply self attention with a residual connection, then masked average pooling.

    :param inp: Input tensor; dimension 1 is read as the number of entities
        when ``self.max_num_ent`` is None.
    :param key_masks: List of mask tensors, concatenated on dimension 1.
        Entries equal to 1 are excluded from the pooling (weight ``1 - mask``).
    :return: The pooled output, i.e. the masked sum over dimension 1 divided
        by the number of unmasked entries plus ``self.EPSILON``.
    :raises UnityTrainerException: When exporting to ONNX while
        ``self.max_num_ent`` is None.
    """
    # Gather the maximum number of entities information
    mask = torch.cat(key_masks, dim=1)

    inp = self.embedding_norm(inp)
    # Feed to self attention
    query = self.fc_q(inp)  # (b, n_q, emb)
    key = self.fc_k(inp)  # (b, n_k, emb)
    value = self.fc_v(inp)  # (b, n_k, emb)

    # Only use max num if provided
    if self.max_num_ent is not None:
        num_ent = self.max_num_ent
    else:
        num_ent = inp.shape[1]
        if exporting_to_onnx.is_exporting():
            # Implicit literal concatenation keeps the message free of the
            # backslash / indentation that a continuation inside the string
            # literal would embed.
            raise UnityTrainerException(
                "Trying to export an attention mechanism that doesn't have a "
                "set max number of elements."
            )

    output, _ = self.attention(query, key, value, num_ent, num_ent, mask)

    # Residual
    output = self.fc_out(output) + inp
    output = self.residual_norm(output)

    # Average Pooling: masked entries get weight 0, the others weight 1.
    keep = 1 - mask
    numerator = torch.sum(output * keep.reshape(-1, num_ent, 1), dim=1)
    # EPSILON is added to the count of unmasked entries (presumably to avoid
    # dividing by zero when every entry is masked).
    denominator = torch.sum(keep, dim=1, keepdim=True) + self.EPSILON
    return numerator / denominator
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
    """
    Encode a visual observation into a flat, ReLU-activated feature vector.

    :param visual_obs: Visual observation tensor with four dimensions (it is
        permuted with a four-element order outside of ONNX export).
    :return: ``relu(self.dense(...))`` applied to the per-sample flattened
        output of ``self.sequential``.
    """
    if not exporting_to_onnx.is_exporting():
        # Move the last axis to position 1 (presumably channels-last to
        # channels-first; confirm against the caller). Skipped during export.
        visual_obs = visual_obs.permute([0, 3, 1, 2])

    features = self.sequential(visual_obs)
    # Flatten everything except the leading (batch) dimension.
    flattened = features.reshape(visual_obs.shape[0], -1)
    return torch.relu(self.dense(flattened))
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Build one padding mask per entity tensor.

    For every tensor in ``entities`` the squared values are summed over
    dimension 2; the mask holds 1.0 where that sum is below 0.01 (the entity
    is treated as all zeros, i.e. padding) and 0.0 elsewhere. This is used in
    the Attention layer to mask the padding observations.

    :param entities: List of entity tensors with at least three dimensions.
    :return: List of float mask tensors, one per input tensor.
    """
    with torch.no_grad():
        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # PyTorch emits a TracerWarning because shape[n].item() may
                # make the trace not generalize to other inputs. It is
                # silenced on purpose: the model is always run with inputs of
                # the same shape.
                warnings.simplefilter("ignore")
                # When exporting to ONNX, we want to transpose the entities.
                # This is because ONNX only supports input in NCHW (channel
                # first) format. Barracuda also expects to get data in NCHW.
                transposed: List[torch.Tensor] = []
                for obs in entities:
                    swapped = torch.transpose(obs, 2, 1)
                    transposed.append(
                        swapped.reshape(
                            -1, obs.shape[1].item(), obs.shape[2].item()
                        )
                    )
                entities = transposed

        # Generate the masking tensor for each entities tensor
        # (mask only if all zeros).
        key_masks: List[torch.Tensor] = []
        for ent in entities:
            squared_sum = torch.sum(ent**2, dim=2)
            key_masks.append((squared_sum < 0.01).float())
    return key_masks
def forward(self, x_self: torch.Tensor, entities: List[torch.Tensor]) -> torch.Tensor:
    """
    Encode several groups of entities and concatenate them into one tensor.

    :param x_self: Self observation; it is reshaped to (-1, 1, self.self_size)
        and only used when ``self.concat_self`` is truthy.
    :param entities: One tensor per entity type. Dimension 1 is read as the
        number of entities when the matching entry of
        ``self.entity_num_max_elements`` is negative.
    :return: The encoded entities of every type, concatenated on dimension 1
        and passed through ``self.embedding_norm``.
    :raises UnityTrainerException: When exporting to ONNX while an entry of
        ``self.entity_num_max_elements`` is negative (no fixed maximum set).
    """
    if self.concat_self:
        # Concatenate all observations with self
        self_and_ent: List[torch.Tensor] = []
        for num_entities, ent in zip(self.entity_num_max_elements, entities):
            if num_entities < 0:
                # A negative value means "no fixed maximum": fall back to the
                # runtime shape, which is not allowed while exporting.
                if exporting_to_onnx.is_exporting():
                    # Implicit literal concatenation keeps the message free
                    # of the backslash / indentation that a continuation
                    # inside the string literal would embed.
                    raise UnityTrainerException(
                        "Trying to export an attention mechanism that doesn't "
                        "have a set max number of elements."
                    )
                num_entities = ent.shape[1]
            # Repeat the self observation once per entity.
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
    else:
        self_and_ent = entities

    # Encode and concatenate entities
    encoded_entities = torch.cat(
        [ent_encoder(x) for ent_encoder, x in zip(self.ent_encoders, self_and_ent)],
        dim=1,
    )
    return self.embedding_norm(encoded_entities)
def forward(self, input_tensor: torch.Tensor, memories: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Run the LSTM using a packed memory tensor as its initial state.

    ``memories`` is sliced on its last dimension at ``self.hidden_size``: the
    first part is used as h0 and the remainder as c0. The new hidden and cell
    states are packed back together on the last dimension.

    :param input_tensor: Input passed to ``self.lstm``.
    :param memories: Three-dimensional packed (h, c) memory tensor.
    :return: Tuple of the LSTM output and the updated packed memories.
    """
    if exporting_to_onnx.is_exporting():
        # This transpose is needed both at input and output of the LSTM when
        # exporting because ONNX will expect (sequence_len, batch, memory_size)
        # instead of (batch, sequence_len, memory_size)
        memories = memories.transpose(0, 1)

    # Plain slicing is used instead of torch.split since the latter is not
    # supported by Barracuda.
    split_at = self.hidden_size
    initial_state = (
        memories[:, :, :split_at].contiguous(),
        memories[:, :, split_at:].contiguous(),
    )

    lstm_out, final_state = self.lstm(input_tensor, initial_state)
    output_mem = torch.cat(final_state, dim=-1)

    if exporting_to_onnx.is_exporting():
        output_mem = output_mem.transpose(0, 1)
    return lstm_out, output_mem
def forward(self, x_self: torch.Tensor, entities: torch.Tensor) -> torch.Tensor:
    """
    Encode entities, optionally concatenated with the self observation.

    :param x_self: Self observation; it is reshaped to (-1, 1, self.self_size)
        and only used when ``self.self_size > 0``.
    :param entities: Entity tensor. Dimension 1 is read as the number of
        entities when no maximum is configured, and the self observation is
        concatenated on dimension 2.
    :return: The output of ``self.self_ent_encoder`` applied to the
        (possibly self-augmented) entities.
    :raises UnityTrainerException: When exporting to ONNX while
        ``self.self_size > 0`` and ``self.entity_num_max_elements`` is
        negative (no fixed maximum set).
    """
    if self.self_size > 0:
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            # A negative value means "no fixed maximum": fall back to the
            # runtime shape, which is not allowed while exporting.
            if exporting_to_onnx.is_exporting():
                # Implicit literal concatenation keeps the message free of
                # the backslash / indentation that a continuation inside the
                # string literal would embed.
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't "
                    "have a set max number of elements."
                )
            num_entities = entities.shape[1]
        # Repeat the self observation once per entity.
        expanded_self = x_self.reshape(-1, 1, self.self_size)
        expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
        # Concatenate all observations with self
        entities = torch.cat([expanded_self, entities], dim=2)

    # Encode entities
    encoded_entities = self.self_ent_encoder(entities)
    return encoded_entities
def forward(
    self,
    vec_inputs: List[torch.Tensor],
    vis_inputs: List[torch.Tensor],
    actions: Optional[torch.Tensor] = None,
    memories: Optional[torch.Tensor] = None,
    sequence_length: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Encode vector and visual observations (and optional actions).

    Each processor in ``self.vector_processors`` / ``self.visual_processors``
    is applied to the input at the same index. The results, plus ``actions``
    when given, are concatenated on the last dimension and passed through
    ``self.linear_encoder``. When ``self.use_lstm`` is set the encoding is
    additionally run through ``self.lstm``.

    :param vec_inputs: Vector observations, indexed like the vector processors.
    :param vis_inputs: Visual observations, indexed like the visual processors.
    :param actions: Optional tensor appended to the processed observations.
    :param memories: Memories forwarded to ``self.lstm``; returned unchanged
        when ``self.use_lstm`` is falsy.
    :param sequence_length: Sequence length used to reshape the encoding
        before the LSTM.
    :return: Tuple of the encoding and the (possibly updated) memories.
    :raises Exception: When no processor produced an encoding.
    """
    encodes = [
        processor(vec_inputs[position])
        for position, processor in enumerate(self.vector_processors)
    ]

    for position, processor in enumerate(self.visual_processors):
        vis_input = vis_inputs[position]
        if not exporting_to_onnx.is_exporting():
            # Move the last axis to position 1 outside of ONNX export.
            vis_input = vis_input.permute([0, 3, 1, 2])
        encodes.append(processor(vis_input))

    if not encodes:
        raise Exception("No valid inputs to network.")

    # Constants don't work in Barracuda
    pieces = encodes if actions is None else encodes + [actions]
    encoding = self.linear_encoder(torch.cat(pieces, dim=-1))

    if self.use_lstm:
        # Resize to (batch, sequence length, encoding size)
        sequenced = encoding.reshape([-1, sequence_length, self.h_size])
        sequenced, memories = self.lstm(sequenced, memories)
        encoding = sequenced.reshape([-1, self.m_size // 2])
    return encoding, memories
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
    """
    Encode a visual observation into a flat, ReLU-activated feature vector.

    :param visual_obs: Visual observation tensor with four dimensions (it is
        permuted with a four-element order outside of ONNX export).
    :return: ``relu(self.dense(...))`` applied to the output of
        ``self.sequential`` reshaped to rows of ``self.final_flat_size``.
    """
    if not exporting_to_onnx.is_exporting():
        # Move the last axis to position 1 (presumably channels-last to
        # channels-first; confirm against the caller). Skipped during export.
        visual_obs = visual_obs.permute([0, 3, 1, 2])

    features = self.sequential(visual_obs)
    flattened = features.reshape(-1, self.final_flat_size)
    activated = torch.relu(self.dense(flattened))
    return activated
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
    """
    Encode a visual observation with convolutional layers and a dense layer.

    :param visual_obs: Visual observation tensor with four dimensions (it is
        permuted with a four-element order outside of ONNX export).
    :return: ``self.dense`` applied to the output of ``self.conv_layers``
        reshaped to rows of ``self.final_flat``. No activation is applied
        here after the dense layer.
    """
    if not exporting_to_onnx.is_exporting():
        # Move the last axis to position 1 (presumably channels-last to
        # channels-first; confirm against the caller). Skipped during export.
        visual_obs = visual_obs.permute([0, 3, 1, 2])

    conv_out = self.conv_layers(visual_obs)
    flattened = conv_out.reshape([-1, self.final_flat])
    return self.dense(flattened)
def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
    """
    Encode a visual observation by flattening it and applying a dense layer.

    :param visual_obs: Visual observation tensor with four dimensions (it is
        permuted with a four-element order outside of ONNX export).
    :return: ``self.dense`` applied to the observation reshaped to rows of
        ``self.input_size``.
    """
    if not exporting_to_onnx.is_exporting():
        # Move the last axis to position 1 so the flattening order matches
        # the layout used during export (presumably channels-first; confirm).
        visual_obs = visual_obs.permute([0, 3, 1, 2])

    flattened = visual_obs.reshape(-1, self.input_size)
    encoded = self.dense(flattened)
    return encoded