def test_get_boxes_shape(self, device, dtype):
    """Validate ``Boxes.get_boxes_shape`` on one box, two boxes and a batch of two boxes.

    Each box is given by its four corner points as ``(x, y)`` pairs, listed in a
    non-canonical order. The expected values below treat both coordinate ranges as
    inclusive: a box spanning x in [1, 3] and y in [1, 2] is 3 wide and 2 high.
    """
    single_corners = torch.tensor(
        [[[1.0, 1.0], [3.0, 2.0], [1.0, 2.0], [3.0, 1.0]]], device=device, dtype=dtype
    )  # (1, 4, 2)
    pair_corners = torch.tensor(
        [
            [[1.0, 1.0], [3.0, 1.0], [1.0, 2.0], [3.0, 2.0]],
            [[5.0, 4.0], [2.0, 2.0], [5.0, 2.0], [2.0, 4.0]],
        ],
        device=device,
        dtype=dtype,
    )  # (2, 4, 2)

    single = Boxes(single_corners)
    pair = Boxes(pair_corners)
    batched_pair = Boxes(pair_corners.unsqueeze(0))  # (1, 2, 4, 2)

    # A single box: exactly one height and one width.
    height, width = single.get_boxes_shape()
    assert height.item() == 2
    assert width.item() == 3

    # Two unbatched boxes: 1-D outputs with one entry per box.
    height, width = pair.get_boxes_shape()
    assert height.ndim == 1
    assert width.ndim == 1
    assert len(height) == 2
    assert len(width) == 2
    assert (height == torch.as_tensor([2.0, 3.0], device=device)).all()
    assert (width == torch.as_tensor([3.0, 4.0], device=device)).all()

    # Same two boxes behind a leading batch dimension of size 1: outputs are (1, 2).
    height, width = batched_pair.get_boxes_shape()
    assert height.ndim == 2
    assert width.ndim == 2
    assert height.shape == (1, 2)
    assert width.shape == (1, 2)
    assert (height == torch.as_tensor([[2.0, 3.0]], device=device)).all()
    assert (width == torch.as_tensor([[3.0, 4.0]], device=device)).all()
def test_get_boxes_shape_batch(self, device, dtype):
    """Check ``Boxes.get_boxes_shape`` on a batch of two entries holding one box each.

    Corners are ``(x, y)`` pairs in arbitrary order; the expected sizes count both
    coordinate ranges inclusively.
    """
    first_entry = torch.tensor(
        [[[1.0, 1.0], [3.0, 2.0], [3.0, 1.0], [1.0, 2.0]]], device=device, dtype=dtype
    )  # (1, 4, 2): x spans [1, 3], y spans [1, 2]
    second_entry = torch.tensor(
        [[[5.0, 2.0], [2.0, 2.0], [5.0, 4.0], [2.0, 4.0]]], device=device, dtype=dtype
    )  # (1, 4, 2): x spans [2, 5], y spans [2, 4]
    batched_boxes = Boxes(torch.stack([first_entry, second_entry]))  # (2, 1, 4, 2)

    heights, widths = batched_boxes.get_boxes_shape()

    # One value per (batch entry, box): both outputs are (2, 1).
    assert heights.ndim == 2
    assert widths.ndim == 2
    assert heights.shape == (2, 1)
    assert widths.shape == (2, 1)
    assert (heights == torch.as_tensor([[2], [3]], device=device)).all()
    assert (widths == torch.as_tensor([[3], [4]], device=device)).all()