Example #1
0
    def test_color_image_channel_size(self, img_size, color, dtype, device):
        """A color whose channel count differs from the image's must be rejected.

        ``img_size`` and ``color`` come from a parametrization that is not
        visible here; presumably each pair has mismatched channel counts.
        """
        expected_msg = 'color must have the same number of channels as the image.'
        canvas = torch.zeros(*img_size, dtype=dtype, device=device)
        start = torch.tensor([0, 0])
        end = torch.tensor([1, 1])
        with pytest.raises(ValueError) as raised:
            draw_line(canvas, start, end, color)
        assert str(raised.value) == expected_msg
Example #2
0
    def test_p2_out_of_bounds(self, p2, dtype, device):
        """An end point ``p2`` outside the 1x8x8 image must raise ValueError.

        ``p2`` is supplied by a parametrization not visible here.
        """
        canvas = torch.zeros(1, 8, 8, dtype=dtype, device=device)
        origin = torch.tensor([0, 0])
        white = torch.tensor([255])
        with pytest.raises(ValueError) as raised:
            draw_line(canvas, origin, p2, white)
        assert str(raised.value) == 'p2 is out of bounds.'
Example #3
0
    def test_image_size(self, img_size, dtype, device):
        """Images that are not 3-dimensional (C, H, W) must be rejected.

        ``img_size`` is supplied by a parametrization not visible here.
        """
        canvas = torch.zeros(*img_size, dtype=dtype, device=device)
        start, end = torch.tensor([0, 0]), torch.tensor([1, 1])
        white = torch.tensor([255])
        with pytest.raises(ValueError) as raised:
            draw_line(canvas, start, end, white)
        assert str(raised.value) == 'image must have 3 dimensions (C,H,W).'
Example #4
0
 def test_draw_line_with_big_coordinates(self, dtype, device):
     """Draw a long horizontal segment on a 1x500x500 image."""
     shape = (1, 500, 500)
     start = torch.tensor([200, 200])
     end = torch.tensor([400, 200])
     drawn = draw_line(torch.zeros(*shape, dtype=dtype, device=device),
                       start, end, torch.tensor([255]))
     # Row 200, columns 200..400 inclusive: both endpoints are painted.
     expected = torch.zeros(*shape, dtype=dtype, device=device)
     expected[:, 200, 200:401] = 255
     assert_close(drawn, expected)
Example #5
0
 def test_draw_line_vertical(self, dtype, device):
     """Test drawing a vertical line from (6, 2) to (6, 0)."""
     img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
     img = draw_line(img, torch.tensor([6, 2]), torch.tensor([6, 0]), torch.tensor([255]))
     # Build the expected image on the same device/dtype/shape as ``img``.
     # Comparing a CUDA tensor against a CPU literal with ``==`` raises a
     # device mismatch, and a 2-D literal would silently broadcast over any
     # leading dimension. assert_close checks shape/dtype/device and reports
     # the differing pixels on failure.
     img_mask = torch.tensor([[
         [0., 0., 0., 0., 0., 0., 255., 0.],
         [0., 0., 0., 0., 0., 0., 255., 0.],
         [0., 0., 0., 0., 0., 0., 255., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
     ]], device=device, dtype=dtype)
     torch.testing.assert_close(img, img_mask)
Example #6
0
 def test_draw_line_m_gte_1(self, dtype, device):
     """Test drawing a line with m >= 1 (steep segment from (3, 7) to (1, 4))."""
     img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
     img = draw_line(img, torch.tensor([3, 7]), torch.tensor([1, 4]), torch.tensor([255]))
     # Expected image shares device/dtype/shape with ``img``, so the check works
     # on CUDA too. assert_close also verifies shape and reports mismatches.
     img_mask = torch.tensor([[
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 255., 0., 0., 0., 0., 0., 0.],
         [0., 0., 255., 0., 0., 0., 0., 0.],
         [0., 0., 255., 0., 0., 0., 0., 0.],
         [0., 0., 0., 255., 0., 0., 0., 0.],
     ]], device=device, dtype=dtype)
     torch.testing.assert_close(img, img_mask)
Example #7
0
 def test_draw_line_m_gt_0_lt_1(self, dtype, device):
     """Test drawing a line with 0 < m < 1 (shallow segment from (0, 0) to (6, 2))."""
     img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
     img = draw_line(img, torch.tensor([0, 0]), torch.tensor([6, 2]), torch.tensor([255]))
     # Expected image shares device/dtype/shape with ``img``, so the check works
     # on CUDA too. assert_close also verifies shape and reports mismatches.
     img_mask = torch.tensor([[
         [255., 255., 0., 0., 0., 0., 0., 0.],
         [0., 0., 255., 255., 255., 0., 0., 0.],
         [0., 0., 0., 0., 0., 255., 255., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
     ]], device=device, dtype=dtype)
     torch.testing.assert_close(img, img_mask)
Example #8
0
 def test_draw_line_m_lt_0_gte_neg1(self, dtype, device):
     """Test drawing a line with -1 <= m < 0 (segment from (1, 5) to (7, 0))."""
     img = torch.zeros(1, 8, 8, dtype=dtype, device=device)
     img = draw_line(img, torch.tensor([1, 5]), torch.tensor([7, 0]), torch.tensor([255]))
     # Expected image shares device/dtype/shape with ``img``, so the check works
     # on CUDA too. assert_close also verifies shape and reports mismatches.
     img_mask = torch.tensor([[
         [0., 0., 0., 0., 0., 0., 0., 255.],
         [0., 0., 0., 0., 0., 0., 255., 0.],
         [0., 0., 0., 0., 0., 255., 0., 0.],
         [0., 0., 0., 255., 255., 0., 0., 0.],
         [0., 0., 255., 0., 0., 0., 0., 0.],
         [0., 255., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0., 0., 0., 0.],
     ]], device=device, dtype=dtype)
     torch.testing.assert_close(img, img_mask)
Example #9
0
 def test_draw_line_m_lte_neg1(self, dtype, device):
     """A steep descending segment (m <= -1) from (0, 7) to (6, 0)."""
     drawn = draw_line(
         torch.zeros(1, 8, 8, dtype=dtype, device=device),
         torch.tensor([0, 7]),
         torch.tensor([6, 0]),
         torch.tensor([255]),
     )
     # (row, col) of every pixel the rasterised segment should cover.
     lit_pixels = [(0, 6), (1, 5), (2, 4), (3, 3), (4, 3), (5, 2), (6, 1), (7, 0)]
     expected = torch.zeros(1, 8, 8, device=device, dtype=dtype)
     for row, col in lit_pixels:
         expected[0, row, col] = 255.
     assert_close(drawn, expected)
Example #10
0
 def test_draw_line_horizontal(self, dtype, device):
     """A horizontal segment drawn right-to-left from (6, 4) to (0, 4)."""
     drawn = draw_line(torch.zeros(1, 8, 8, dtype=dtype, device=device),
                       torch.tensor([6, 4]), torch.tensor([0, 4]),
                       torch.tensor([255]))
     # Row 4, columns 0..6 inclusive; column 7 stays untouched.
     expected = torch.zeros(1, 8, 8, device=device, dtype=dtype)
     expected[:, 4, :7] = 255.
     assert_close(drawn, expected)
Example #11
0
 def test_draw_line_m_lt_0_gte_neg1(self, dtype, device):
     """A shallow descending segment (-1 <= m < 0) from (1, 5) to (7, 0)."""
     drawn = draw_line(
         torch.zeros(1, 8, 8, dtype=dtype, device=device),
         torch.tensor([1, 5]),
         torch.tensor([7, 0]),
         torch.tensor([255]),
     )
     # (row, col) of every pixel the rasterised segment should cover.
     lit_pixels = [(0, 7), (1, 6), (2, 5), (3, 3), (3, 4), (4, 2), (5, 1)]
     expected = torch.zeros(1, 8, 8, device=device, dtype=dtype)
     for row, col in lit_pixels:
         expected[0, row, col] = 255.0
     assert_close(drawn, expected)
Example #12
0
 def test_draw_line_m_gte_1(self, dtype, device):
     """A steep ascending segment (m >= 1) from (3, 7) to (1, 4)."""
     drawn = draw_line(
         torch.zeros(1, 8, 8, dtype=dtype, device=device),
         torch.tensor([3, 7]),
         torch.tensor([1, 4]),
         torch.tensor([255]),
     )
     # Only rows 4..7 are touched; each listed (row, col) must be painted.
     lit_pixels = [(4, 1), (5, 2), (6, 2), (7, 3)]
     expected = torch.zeros(1, 8, 8, device=device, dtype=dtype)
     for row, col in lit_pixels:
         expected[0, row, col] = 255.0
     assert_close(drawn, expected)
Example #13
0
 def test_draw_line_vertical(self, dtype, device):
     """A vertical segment drawn bottom-to-top from (6, 2) to (6, 0)."""
     drawn = draw_line(torch.zeros(1, 8, 8, dtype=dtype, device=device),
                       torch.tensor([6, 2]), torch.tensor([6, 0]),
                       torch.tensor([255]))
     # Column 6, rows 0..2 inclusive.
     expected = torch.zeros(1, 8, 8, device=device, dtype=dtype)
     expected[:, :3, 6] = 255.0
     assert_close(drawn, expected)
Example #14
0
 def test_draw_line_m_gt_0_lt_1(self, dtype, device):
     """A shallow ascending segment (0 < m < 1) from (0, 0) to (6, 2)."""
     drawn = draw_line(
         torch.zeros(1, 8, 8, dtype=dtype, device=device),
         torch.tensor([0, 0]),
         torch.tensor([6, 2]),
         torch.tensor([255]),
     )
     # (row, col) of every pixel the rasterised segment should cover.
     lit_pixels = [(0, 0), (0, 1), (1, 2), (1, 3), (1, 4), (2, 5), (2, 6)]
     expected = torch.zeros(1, 8, 8, device=device, dtype=dtype)
     for row, col in lit_pixels:
         expected[0, row, col] = 255.0
     assert_close(drawn, expected)
Example #15
0
    def test_point_size(self, p1, p2, dtype, device):
        """End points that are not length-2 coordinates must be rejected.

        ``p1`` and ``p2`` come from a parametrization not visible here.
        """
        canvas = torch.zeros(1, 8, 8, dtype=dtype, device=device)
        white = torch.tensor([255])
        with pytest.raises(ValueError) as raised:
            draw_line(canvas, p1, p2, white)
        assert str(raised.value) == 'p1 and p2 must have length 2.'