def test_stack(self):
    """Check nestedtensor.stack against per-leaf torch.stack for dims 0-2.

    Builds two nested tensors with mismatched lengths ([a, b] vs [c]) and
    asserts the stacked result equals the explicitly constructed expected
    nested tensor for each stacking dimension.
    """
    a = torch.arange(12).reshape(3, 4)
    b = a + 12
    c = b + 12
    # (The original also built nested_tensor([[a, b], [c]]) into an unused
    # local `nt`; removed as dead code.)
    nt0 = nestedtensor.nested_tensor([a, b])
    nt1 = nestedtensor.nested_tensor([c])
    # dim=0: stacking at the outermost level nests the inputs side by side.
    self.assertEqual(
        nestedtensor.stack([nt0, nt1], dim=0),
        ntnt([[a, b], [c]]))
    # dim=1: leaves are paired positionally across inputs; the leaf with no
    # partner (b) gains a size-1 stacking axis via reshape.
    self.assertEqual(
        nestedtensor.stack([nt0, nt1], dim=1),
        ntnt([torch.stack([a, c]), b.reshape(1, 3, 4)]))
    # dim=2: same pairing, but the new axis is inserted one level deeper.
    self.assertEqual(
        nestedtensor.stack([nt0, nt1], dim=2),
        ntnt([torch.stack([a, c], dim=1), b.reshape(3, 1, 4)]))
def forward(self, tensor_list):
    """Compute sine/cosine positional embeddings for a list of feature maps.

    A nestedtensor adaptation of DETR-style sinusoidal position encoding:
    builds per-pixel (y, x) positions from cumulative sums of an all-True
    mask, scales them by a geometric frequency schedule, and interleaves
    sin/cos channels. Returns a nested tensor with the embedding channels
    flattened into the last dimension.
    """
    # One mask per input tensor. ones_like(...).prod(0) collapses the
    # leading (channel) dim, leaving an all-True mask with the tensor's
    # remaining (spatial) shape.
    not_mask = []
    for tensor in tensor_list:
        not_mask.append((torch.ones_like(tensor, dtype=torch.bool).prod(0)).bool())
    # NOTE(review): `tensor` here is the for-loop variable leaking past the
    # loop — device is taken from the *last* tensor; presumably all inputs
    # share one device. Confirm with callers.
    not_mask = nestedtensor.nested_tensor(not_mask, dtype=torch.bool, device=tensor.device)
    # Cumulative counts along height (dim 1) and width (dim 2) give 1-based
    # row/column position indices per pixel.
    y_embed = not_mask.cumsum(1, dtype=torch.float32)
    x_embed = not_mask.cumsum(2, dtype=torch.float32)
    if self.normalize:
        eps = 1e-6
        # Normalize positions by their maximum (the last cumsum entry along
        # the corresponding axis) and rescale to [0, self.scale].
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
    dim_t = torch.arange(self.num_pos_feats)
    # Replicate the channel-index vector for each input so division below
    # broadcasts per nested entry.
    dim_t = nestedtensor.nested_tensor(len(tensor_list) * [dim_t], dtype=torch.float32, device=tensor.device)
    # Frequency schedule: temperature ** (2*floor(i/2)/num_pos_feats), so
    # consecutive channel pairs share a frequency for the sin/cos split.
    dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # Concatenate y-then-x along the channel axis, then interleave:
    # sin on even channels, cos on odd channels.
    pos = nestedtensor.cat((pos_y, pos_x), dim=3)
    pos_sin = pos[:, :, :, 0::2].sin()
    pos_cos = pos[:, :, :, 1::2].cos()
    # Stack the sin/cos halves on a new trailing axis and flatten it back
    # into the channel dimension.
    res = nestedtensor.stack((pos_sin, pos_cos), dim=4)
    return res.flatten(3)