Example #1
    def forward(self, x_self: torch.Tensor,
                entities: torch.Tensor) -> torch.Tensor:
        num_entities = self.entity_num_max_elements
        if num_entities < 0:
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max \
                    number of elements.")
            num_entities = entities.shape[1]

        if exporting_to_onnx.is_exporting():
            # When exporting to ONNX, we want to transpose the entities. This is
            # because ONNX only supports input in NCHW (channel-first) format.
            # Barracuda also expects to get data in NCHW.
            entities = torch.transpose(entities, 2,
                                       1).reshape(-1, num_entities,
                                                  self.entity_size)

        if self.self_size > 0:
            expanded_self = x_self.reshape(-1, 1, self.self_size)
            expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
            # Concatenate all observations with self
            entities = torch.cat([expanded_self, entities], dim=2)
        # Encode entities
        encoded_entities = self.self_ent_encoder(entities)
        return encoded_entities
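
A minimal standalone sketch of the self/entity concatenation step above, with made-up sizes (batch, num_entities, self_size, entity_size are illustrative): x_self is tiled once per entity and joined to each entity's features along the last dimension.

import torch

batch, num_entities, self_size, entity_size = 2, 3, 4, 5
x_self = torch.randn(batch, self_size)
entities = torch.randn(batch, num_entities, entity_size)
expanded_self = x_self.reshape(-1, 1, self_size)
# Repeat the self observation once per entity: (batch, num_entities, self_size)
expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
# Each entity now carries the self observation: (batch, num_entities, self_size + entity_size)
entities_with_self = torch.cat([expanded_self, entities], dim=2)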
Example #2
    def forward(self, inp: torch.Tensor,
                key_masks: List[torch.Tensor]) -> torch.Tensor:
        # Concatenate the masks from all entity groups
        mask = torch.cat(key_masks, dim=1)

        inp = self.embedding_norm(inp)
        # Feed to self attention
        query = self.fc_q(inp)  # (b, n_q, emb)
        key = self.fc_k(inp)  # (b, n_k, emb)
        value = self.fc_v(inp)  # (b, n_k, emb)

        # Only use max num if provided
        if self.max_num_ent is not None:
            num_ent = self.max_num_ent
        else:
            num_ent = inp.shape[1]
            if exporting_to_onnx.is_exporting():
                raise UnityTrainerException(
                    "Trying to export an attention mechanism that doesn't have a set max \
                    number of elements.")

        output, _ = self.attention(query, key, value, num_ent, num_ent, mask)
        # Residual
        output = self.fc_out(output) + inp
        output = self.residual_norm(output)
        # Average Pooling
        numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1),
                              dim=1)
        denominator = torch.sum(1 - mask, dim=1, keepdim=True) + self.EPSILON
        output = numerator / denominator
        return output
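
A hedged sketch of the masked average pooling at the end of this example, assuming (as in Example #4) that a mask value of 1 marks a padded entity and 0 marks a real one, so only real entities contribute to the mean; the sizes and EPSILON value are illustrative.

import torch

EPSILON = 1e-7  # stand-in for self.EPSILON
b, num_ent, emb = 2, 4, 8
output = torch.randn(b, num_ent, emb)
mask = torch.tensor([[0.0, 0.0, 1.0, 1.0],
                     [0.0, 1.0, 1.0, 1.0]])  # trailing entities are padding
numerator = torch.sum(output * (1 - mask).reshape(-1, num_ent, 1), dim=1)
denominator = torch.sum(1 - mask, dim=1, keepdim=True) + EPSILON
pooled = numerator / denominator  # (b, emb): mean over the real entities only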
Example #3
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     batch_size = visual_obs.shape[0]
     hidden = self.sequential(visual_obs)
     before_out = hidden.reshape(batch_size, -1)
     return torch.relu(self.dense(before_out))
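
A small sketch of the channel-order handling used here (and in Examples #9-#11), assuming observations arrive as NHWC at training time while Conv2d expects NCHW; the layer sizes are illustrative, not the encoder's real configuration.

import torch

visual_obs = torch.randn(8, 84, 84, 3)            # (batch, height, width, channels)
visual_obs = visual_obs.permute([0, 3, 1, 2])     # -> (batch, channels, height, width)
conv = torch.nn.Conv2d(3, 16, kernel_size=8, stride=4)
hidden = torch.relu(conv(visual_obs))
flattened = hidden.reshape(visual_obs.shape[0], -1)  # flatten for a dense layer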
Example #4
def get_zero_entities_mask(entities: List[torch.Tensor]) -> List[torch.Tensor]:
    """
    Takes a List of Tensors and returns a List of mask Tensors with 1 if the input was
    all zeros (on dimension 2) and 0 otherwise. This is used in the Attention
    layer to mask the padding observations.
    """
    with torch.no_grad():

        if exporting_to_onnx.is_exporting():
            with warnings.catch_warnings():
                # We ignore a TracerWarning from PyTorch that warns that doing
                # shape[n].item() will cause the trace to be incorrect (the trace might
                # not generalize to other inputs)
                # We ignore this warning because we know the model will always be
                # run with inputs of the same shape
                warnings.simplefilter("ignore")
                # When exporting to ONNX, we want to transpose the entities. This is
                # because ONNX only supports input in NCHW (channel-first) format.
                # Barracuda also expects to get data in NCHW.
                entities = [
                    torch.transpose(obs, 2, 1).reshape(-1, obs.shape[1].item(),
                                                       obs.shape[2].item())
                    for obs in entities
                ]

        # Generate the masking tensors for each entities tensor (mask only if all zeros)
        key_masks: List[torch.Tensor] = [
            (torch.sum(ent**2, axis=2) < 0.01).float() for ent in entities
        ]
    return key_masks
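
A minimal sketch of the masking rule above, outside the ONNX-export path, with made-up tensor sizes: an entity whose features are all zeros on the last dimension is treated as padding and receives a mask value of 1.

import torch

entities = torch.zeros(2, 3, 4)        # (batch, num_entities, entity_size)
entities[:, 0] = torch.randn(2, 4)     # only the first entity of each batch is "real"
key_mask = (torch.sum(entities ** 2, dim=2) < 0.01).float()
# key_mask is (batch, num_entities); padded rows -> 1.0, real rows -> 0.0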
Example #5
 def forward(self, x_self: torch.Tensor,
             entities: List[torch.Tensor]) -> Tuple[torch.Tensor, int]:
     if self.concat_self:
         # Concatenate all observations with self
         self_and_ent: List[torch.Tensor] = []
         for num_entities, ent in zip(self.entity_num_max_elements,
                                      entities):
             if num_entities < 0:
                 if exporting_to_onnx.is_exporting():
                     raise UnityTrainerException(
                         "Trying to export an attention mechanism that doesn't have a set max \
                         number of elements.")
                 num_entities = ent.shape[1]
             expanded_self = x_self.reshape(-1, 1, self.self_size)
             expanded_self = torch.cat([expanded_self] * num_entities,
                                       dim=1)
             self_and_ent.append(torch.cat([expanded_self, ent], dim=2))
     else:
         self_and_ent = entities
     # Encode and concatenate entities
     encoded_entities = torch.cat(
         [
             ent_encoder(x)
             for ent_encoder, x in zip(self.ent_encoders, self_and_ent)
         ],
         dim=1,
     )
     encoded_entities = self.embedding_norm(encoded_entities)
     return encoded_entities
Example #6
    def forward(self, input_tensor: torch.Tensor,
                memories: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:

        if exporting_to_onnx.is_exporting():
            # This transpose is needed both at input and output of the LSTM when
            # exporting because ONNX will expect (sequence_len, batch, memory_size)
            # instead of (batch, sequence_len, memory_size)
            memories = torch.transpose(memories, 0, 1)

        # We don't use torch.split here since it is not supported by Barracuda
        h0 = memories[:, :, :self.hidden_size].contiguous()
        c0 = memories[:, :, self.hidden_size:].contiguous()

        hidden = (h0, c0)
        lstm_out, hidden_out = self.lstm(input_tensor, hidden)
        output_mem = torch.cat(hidden_out, dim=-1)

        if exporting_to_onnx.is_exporting():
            output_mem = torch.transpose(output_mem, 0, 1)

        return lstm_out, output_mem
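
A standalone sketch of the memory layout this module relies on, with illustrative sizes: the hidden state h0 and cell state c0 are packed side by side along the last dimension, sliced apart without torch.split, and re-packed after the LSTM step.

import torch

hidden_size = 16
lstm = torch.nn.LSTM(input_size=8, hidden_size=hidden_size, batch_first=True)
memories = torch.zeros(1, 4, 2 * hidden_size)    # (num_layers, batch, 2 * hidden_size)
h0 = memories[:, :, :hidden_size].contiguous()
c0 = memories[:, :, hidden_size:].contiguous()
x = torch.randn(4, 5, 8)                         # (batch, seq_len, input_size)
lstm_out, (h_n, c_n) = lstm(x, (h0, c0))
output_mem = torch.cat((h_n, c_n), dim=-1)       # back to the packed layout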
Example #7
 def forward(self, x_self: torch.Tensor,
             entities: torch.Tensor) -> torch.Tensor:
     if self.self_size > 0:
         num_entities = self.entity_num_max_elements
         if num_entities < 0:
             if exporting_to_onnx.is_exporting():
                 raise UnityTrainerException(
                     "Trying to export an attention mechanism that doesn't have a set max \
                     number of elements.")
             num_entities = entities.shape[1]
         expanded_self = x_self.reshape(-1, 1, self.self_size)
         expanded_self = torch.cat([expanded_self] * num_entities, dim=1)
         # Concatenate all observations with self
         entities = torch.cat([expanded_self, entities], dim=2)
     # Encode entities
     encoded_entities = self.self_ent_encoder(entities)
     return encoded_entities
Example #8
    def forward(
        self,
        vec_inputs: List[torch.Tensor],
        vis_inputs: List[torch.Tensor],
        actions: Optional[torch.Tensor] = None,
        memories: Optional[torch.Tensor] = None,
        sequence_length: int = 1,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        encodes = []
        for idx, processor in enumerate(self.vector_processors):
            vec_input = vec_inputs[idx]
            processed_vec = processor(vec_input)
            encodes.append(processed_vec)

        for idx, processor in enumerate(self.visual_processors):
            vis_input = vis_inputs[idx]
            if not exporting_to_onnx.is_exporting():
                vis_input = vis_input.permute([0, 3, 1, 2])
            processed_vis = processor(vis_input)
            encodes.append(processed_vis)

        if len(encodes) == 0:
            raise Exception("No valid inputs to network.")

        # Constants don't work in Barracuda
        if actions is not None:
            inputs = torch.cat(encodes + [actions], dim=-1)
        else:
            inputs = torch.cat(encodes, dim=-1)
        encoding = self.linear_encoder(inputs)

        if self.use_lstm:
            # Resize to (batch, sequence length, encoding size)
            encoding = encoding.reshape([-1, sequence_length, self.h_size])
            encoding, memories = self.lstm(encoding, memories)
            encoding = encoding.reshape([-1, self.m_size // 2])
        return encoding, memories
Example #9
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = self.sequential(visual_obs)
     before_out = hidden.reshape(-1, self.final_flat_size)
     return torch.relu(self.dense(before_out))
Example #10
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = self.conv_layers(visual_obs)
     hidden = hidden.reshape([-1, self.final_flat])
     return self.dense(hidden)
Example #11
 def forward(self, visual_obs: torch.Tensor) -> torch.Tensor:
     if not exporting_to_onnx.is_exporting():
         visual_obs = visual_obs.permute([0, 3, 1, 2])
     hidden = visual_obs.reshape(-1, self.input_size)
     return self.dense(hidden)