def test_color_image_channel_size(self, img_size, color, dtype, device):
    """draw_line must reject a color whose channel count differs from the image's.

    ``img_size`` and ``color`` come from parametrization not visible here;
    presumably each pair has mismatching channel counts.
    """
    image = torch.zeros(*img_size, dtype=dtype, device=device)
    start = torch.tensor([0, 0])
    end = torch.tensor([1, 1])
    with pytest.raises(ValueError) as err:
        draw_line(image, start, end, color)
    expected_msg = 'color must have the same number of channels as the image.'
    assert str(err.value) == expected_msg
def test_p2_out_of_bounds(self, p2, dtype, device):
    """draw_line must raise ValueError when the end point ``p2`` is out of bounds.

    ``p2`` is supplied by parametrization not visible here.
    """
    image = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    origin = torch.tensor([0, 0])
    white = torch.tensor([255])
    with pytest.raises(ValueError) as err:
        draw_line(image, origin, p2, white)
    assert str(err.value) == 'p2 is out of bounds.'
def test_image_size(self, img_size, dtype, device):
    """draw_line must raise ValueError for an image that is not (C, H, W).

    ``img_size`` comes from parametrization not visible here; presumably it
    yields shapes whose rank is not 3.
    """
    image = torch.zeros(*img_size, dtype=dtype, device=device)
    start, end = torch.tensor([0, 0]), torch.tensor([1, 1])
    white = torch.tensor([255])
    with pytest.raises(ValueError) as err:
        draw_line(image, start, end, white)
    assert str(err.value) == 'image must have 3 dimensions (C,H,W).'
def test_draw_line_with_big_coordinates(self, dtype, device):
    """Draw a long horizontal line at large coordinates on a 500x500 image."""
    side = 500
    canvas = torch.zeros(1, side, side, dtype=dtype, device=device)
    canvas = draw_line(
        canvas,
        torch.tensor([200, 200]),
        torch.tensor([400, 200]),
        torch.tensor([255]),
    )
    expected = torch.zeros(1, side, side, dtype=dtype, device=device)
    # Row 200, columns 200..400 inclusive: both endpoints are painted.
    expected[:, 200, 200:401] = 255
    assert_close(canvas, expected)
def test_draw_line_vertical(self, dtype, device):
    """Test drawing a vertical line."""
    # NOTE(review): another test named test_draw_line_vertical appears later in
    # this file; if both live in the same class, this one is shadowed and never
    # runs. Confirm and drop one of them.
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([6, 2]), torch.tensor([6, 0]), torch.tensor([255]))
    # Build the expected image on the same device/dtype as ``img``.
    # Comparing against a CPU tensor fails on non-CPU devices, and
    # assert_close reports which pixels differ instead of a bare False.
    expected = torch.tensor(
        [[
            [0., 0., 0., 0., 0., 0., 255., 0.],
            [0., 0., 0., 0., 0., 0., 255., 0.],
            [0., 0., 0., 0., 0., 0., 255., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
        ]],
        device=device,
        dtype=dtype,
    )
    assert_close(img, expected)
def test_draw_line_m_gte_1(self, dtype, device):
    """Test drawing a line with m >= 1."""
    # NOTE(review): another test named test_draw_line_m_gte_1 appears later in
    # this file; if both live in the same class, this one is shadowed and never
    # runs. Confirm and drop one of them.
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([3, 7]), torch.tensor([1, 4]), torch.tensor([255]))
    # Build the expected image on the same device/dtype as ``img``.
    # Comparing against a CPU tensor fails on non-CPU devices, and
    # assert_close reports which pixels differ instead of a bare False.
    expected = torch.tensor(
        [[
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 255., 0., 0., 0., 0., 0., 0.],
            [0., 0., 255., 0., 0., 0., 0., 0.],
            [0., 0., 255., 0., 0., 0., 0., 0.],
            [0., 0., 0., 255., 0., 0., 0., 0.],
        ]],
        device=device,
        dtype=dtype,
    )
    assert_close(img, expected)
def test_draw_line_m_gt_0_lt_1(self, dtype, device):
    """Test drawing a line with 0 < m < 1."""
    # NOTE(review): another test named test_draw_line_m_gt_0_lt_1 appears later
    # in this file; if both live in the same class, this one is shadowed and
    # never runs. Confirm and drop one of them.
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([0, 0]), torch.tensor([6, 2]), torch.tensor([255]))
    # Build the expected image on the same device/dtype as ``img``.
    # Comparing against a CPU tensor fails on non-CPU devices, and
    # assert_close reports which pixels differ instead of a bare False.
    expected = torch.tensor(
        [[
            [255., 255., 0., 0., 0., 0., 0., 0.],
            [0., 0., 255., 255., 255., 0., 0., 0.],
            [0., 0., 0., 0., 0., 255., 255., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
        ]],
        device=device,
        dtype=dtype,
    )
    assert_close(img, expected)
def test_draw_line_m_lt_0_gte_neg1(self, dtype, device):
    """Test drawing a line with -1 < m < 0."""
    # NOTE(review): another test named test_draw_line_m_lt_0_gte_neg1 appears
    # later in this file; if both live in the same class, this one is shadowed
    # and never runs. Confirm and drop one of them.
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([1, 5]), torch.tensor([7, 0]), torch.tensor([255]))
    # Build the expected image on the same device/dtype as ``img``.
    # Comparing against a CPU tensor fails on non-CPU devices, and
    # assert_close reports which pixels differ instead of a bare False.
    expected = torch.tensor(
        [[
            [0., 0., 0., 0., 0., 0., 0., 255.],
            [0., 0., 0., 0., 0., 0., 255., 0.],
            [0., 0., 0., 0., 0., 255., 0., 0.],
            [0., 0., 0., 255., 255., 0., 0., 0.],
            [0., 0., 255., 0., 0., 0., 0., 0.],
            [0., 255., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
            [0., 0., 0., 0., 0., 0., 0., 0.],
        ]],
        device=device,
        dtype=dtype,
    )
    assert_close(img, expected)
def test_draw_line_m_lte_neg1(self, dtype, device):
    """Draw a line whose slope satisfies m <= -1 and check every pixel."""
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([0, 7]), torch.tensor([6, 0]), torch.tensor([255]))
    # (row, col) of every pixel the line is expected to cover.
    lit = {(0, 6), (1, 5), (2, 4), (3, 3), (4, 3), (5, 2), (6, 1), (7, 0)}
    plane = [[255. if (r, c) in lit else 0. for c in range(8)] for r in range(8)]
    expected = torch.tensor([plane], device=device, dtype=dtype)
    assert_close(img, expected)
def test_draw_line_horizontal(self, dtype, device):
    """Draw a right-to-left horizontal line along row 4 and check every pixel."""
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([6, 4]), torch.tensor([0, 4]), torch.tensor([255]))
    # Row 4, columns 0..6 inclusive are painted; column 7 stays dark.
    lit = {(4, col) for col in range(7)}
    plane = [[255. if (r, c) in lit else 0. for c in range(8)] for r in range(8)]
    expected = torch.tensor([plane], device=device, dtype=dtype)
    assert_close(img, expected)
def test_draw_line_m_lt_0_gte_neg1(self, dtype, device):
    """Draw a line with -1 < m < 0 and check every pixel."""
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([1, 5]), torch.tensor([7, 0]), torch.tensor([255]))
    # (row, col) of every pixel the line is expected to cover.
    lit = {(0, 7), (1, 6), (2, 5), (3, 3), (3, 4), (4, 2), (5, 1)}
    plane = [[255.0 if (r, c) in lit else 0.0 for c in range(8)] for r in range(8)]
    expected = torch.tensor([plane], device=device, dtype=dtype)
    assert_close(img, expected)
def test_draw_line_m_gte_1(self, dtype, device):
    """Draw a line with m >= 1 and check every pixel."""
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([3, 7]), torch.tensor([1, 4]), torch.tensor([255]))
    # (row, col) of every pixel the line is expected to cover.
    lit = {(4, 1), (5, 2), (6, 2), (7, 3)}
    plane = [[255.0 if (r, c) in lit else 0.0 for c in range(8)] for r in range(8)]
    expected = torch.tensor([plane], device=device, dtype=dtype)
    assert_close(img, expected)
def test_draw_line_vertical(self, dtype, device):
    """Draw a bottom-to-top vertical line in column 6 and check every pixel."""
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([6, 2]), torch.tensor([6, 0]), torch.tensor([255]))
    # Column 6, rows 0..2 inclusive are painted.
    lit = {(row, 6) for row in range(3)}
    plane = [[255.0 if (r, c) in lit else 0.0 for c in range(8)] for r in range(8)]
    expected = torch.tensor([plane], device=device, dtype=dtype)
    assert_close(img, expected)
def test_draw_line_m_gt_0_lt_1(self, dtype, device):
    """Draw a shallow line with 0 < m < 1 and check every pixel."""
    img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    img = draw_line(img, torch.tensor([0, 0]), torch.tensor([6, 2]), torch.tensor([255]))
    # (row, col) of every pixel the line is expected to cover.
    lit = {(0, 0), (0, 1), (1, 2), (1, 3), (1, 4), (2, 5), (2, 6)}
    plane = [[255.0 if (r, c) in lit else 0.0 for c in range(8)] for r in range(8)]
    expected = torch.tensor([plane], device=device, dtype=dtype)
    assert_close(img, expected)
def test_point_size(self, p1, p2, dtype, device):
    """draw_line must raise ValueError unless both endpoints have length 2.

    ``p1`` and ``p2`` come from parametrization not visible here.
    """
    image = torch.zeros(1, 8, 8, dtype=dtype, device=device)
    white = torch.tensor([255])
    with pytest.raises(ValueError) as err:
        draw_line(image, p1, p2, white)
    assert str(err.value) == 'p1 and p2 must have length 2.'