Ejemplo n.º 1
0
    def test_cat(self):
        """Concatenate two nested tensors along nested dims 0, 1 and 2.

        The left operand holds two 3x4 integer tensors and the right operand
        holds one; each concatenation is compared against a nested tensor
        built directly from the expected constituents.
        """
        base = torch.arange(12).reshape(3, 4)
        shifted = base + 12
        shifted_twice = shifted + 12

        left = ntnt([base, shifted])
        right = ntnt([shifted_twice])

        # dim=0 (through torch.cat): constituents are appended in order.
        along_0 = torch.cat([left, right], dim=0)
        self.assertEqual(along_0, ntnt_nograd([base, shifted, shifted_twice]))

        # dim=1: the expected result joins base and shifted_twice along their
        # own dim 0 and leaves shifted untouched.
        along_1 = nestedtensor.cat([left, right], dim=1)
        self.assertEqual(along_1,
                         ntnt_nograd([torch.cat([base, shifted_twice]),
                                      shifted]))

        # dim=2: same pairing, but base and shifted_twice are joined along
        # their own dim 1.
        along_2 = nestedtensor.cat([left, right], dim=2)
        self.assertEqual(along_2,
                         ntnt_nograd([torch.cat([base, shifted_twice], dim=1),
                                      shifted]))
Ejemplo n.º 2
0
 def forward(self, tensor_list):
     """Build a sine/cosine positional encoding for each tensor in ``tensor_list``.

     For every input tensor an all-``True`` mask over all dims except the
     first is built. Its cumulative sums along nested dims 1 and 2 give
     per-position y and x coordinates. These are optionally normalised by
     their last value and scaled by ``self.scale``. They are then divided by
     ``self.temperature ** (2 * (i // 2) / self.num_pos_feats)`` for
     ``i in range(self.num_pos_feats)`` and concatenated (y first, then x)
     along dim 3. Finally ``sin`` is applied at even feature indices and
     ``cos`` at odd ones, and the two results are interleaved.

     Args:
         tensor_list: iterable of tensors. Presumably each is ``(C, H, W)``;
             TODO confirm against callers.

     Returns:
         A nested tensor, the result of ``flatten(3)``. For ``(C, H, W)``
         inputs each constituent would be ``(H, W, 2 * self.num_pos_feats)``;
         confirm.

     Raises:
         ValueError: if ``tensor_list`` contains no tensors.
     """
     masks = []
     device = None
     for tensor in tensor_list:
         # All-True mask over every dim but the first, allocated directly at
         # the reduced shape. The previous form built a full-size tensor and
         # reduced it with prod(0).
         masks.append(
             torch.ones(tensor.shape[1:], dtype=torch.bool,
                        device=tensor.device))
         # As before, the nested tensors below live on the device of the last
         # input tensor. The device is now recorded explicitly rather than
         # read from the leaked loop variable.
         device = tensor.device
     if not masks:
         raise ValueError("tensor_list must contain at least one tensor")
     not_mask = nestedtensor.nested_tensor(masks,
                                           dtype=torch.bool,
                                           device=device)
     # Running counts of valid positions give 1-based y / x coordinates.
     y_embed = not_mask.cumsum(1, dtype=torch.float32)
     x_embed = not_mask.cumsum(2, dtype=torch.float32)
     if self.normalize:
         eps = 1e-6  # keeps the denominator strictly positive
         y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
         x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
     dim_t = torch.arange(self.num_pos_feats)
     # One copy of the frequency vector per input tensor.
     dim_t = nestedtensor.nested_tensor(len(masks) * [dim_t],
                                        dtype=torch.float32,
                                        device=device)
     # Consecutive feature pairs (2k, 2k+1) share one frequency.
     dim_t = self.temperature**(2 * (dim_t // 2) / self.num_pos_feats)
     pos_x = x_embed[:, :, :, None] / dim_t
     pos_y = y_embed[:, :, :, None] / dim_t
     # NOTE(review): sin/cos pairs are taken over the concatenated y|x
     # features. When num_pos_feats is odd, one pair straddles the y/x
     # boundary; confirm this is intended.
     pos = nestedtensor.cat((pos_y, pos_x), dim=3)
     pos_sin = pos[:, :, :, 0::2].sin()
     pos_cos = pos[:, :, :, 1::2].cos()
     # Stack then flatten interleaves: sin(f0), cos(f1), sin(f2), cos(f3), ...
     res = nestedtensor.stack((pos_sin, pos_cos), dim=4)
     return res.flatten(3)