def test_cat(self):
    """Concatenation of nested tensors along dims 0, 1 and 2."""
    first = torch.arange(12).reshape(3, 4)
    second = first + 12
    third = second + 12
    left = ntnt([first, second])
    right = ntnt([third])
    # dim=0: components of the second nested tensor are appended
    # after those of the first.
    self.assertEqual(
        torch.cat([left, right], dim=0),
        ntnt_nograd([first, second, third]))
    # dim=1: expected result pairs `third` with `first` along the
    # components' first tensor dimension.
    self.assertEqual(
        nestedtensor.cat([left, right], dim=1),
        ntnt_nograd([torch.cat([first, third]), second]))
    # dim=2: same pairing, concatenated along the components' second
    # tensor dimension.
    self.assertEqual(
        nestedtensor.cat([left, right], dim=2),
        ntnt_nograd([torch.cat([first, third], dim=1), second]))
def forward(self, tensor_list):
    # Builds sine/cosine positional embeddings for a list of tensors,
    # computed on nested tensors.  NOTE(review): this closely mirrors
    # DETR's PositionEmbeddingSine — confirm against the caller.
    not_mask = []
    for tensor in tensor_list:
        # Per-input "valid position" map: an all-ones mask collapsed
        # over dim 0.  Built from ones_like, so every position is valid.
        not_mask.append((torch.ones_like(tensor, dtype=torch.bool).prod(0)).bool())
    # NOTE(review): `tensor` here is the leftover loop variable used only
    # for its .device; an empty tensor_list would raise NameError.
    not_mask = nestedtensor.nested_tensor(
        not_mask, dtype=torch.bool, device=tensor.device)
    # Cumulative sums over the mask yield 1-based row (y) and column (x)
    # coordinate grids.
    y_embed = not_mask.cumsum(1, dtype=torch.float32)
    x_embed = not_mask.cumsum(2, dtype=torch.float32)
    if self.normalize:
        eps = 1e-6
        # Rescale coordinates by their last row/column value into
        # (0, self.scale]; eps guards the division.
        y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
        x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
    # Frequency divisors: temperature^(2*floor(i/2)/num_pos_feats),
    # replicated once per input as a nested tensor on the same device.
    dim_t = torch.arange(self.num_pos_feats)
    dim_t = nestedtensor.nested_tensor(
        len(tensor_list) * [dim_t], dtype=torch.float32, device=tensor.device)
    dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)
    pos_x = x_embed[:, :, :, None] / dim_t
    pos_y = y_embed[:, :, :, None] / dim_t
    # y features first, then x, along the feature axis (dim 3).
    pos = nestedtensor.cat((pos_y, pos_x), dim=3)
    # sin of even channels, cos of odd channels; stacking on a new
    # trailing dim then flattening from dim 3 interleaves the pairs.
    pos_sin = pos[:, :, :, 0::2].sin()
    pos_cos = pos[:, :, :, 1::2].cos()
    res = nestedtensor.stack((pos_sin, pos_cos), dim=4)
    return res.flatten(3)