Example #1
0
def test_draw_no_segmention_mask():
    """Drawing an empty stack of masks must warn and return an unchanged image."""
    image = torch.full((3, 100, 100), 0, dtype=torch.uint8)
    # Zero masks, each of the same spatial size as the image.
    empty_masks = torch.full((0, 100, 100), 0, dtype=torch.bool)
    expected_message = re.escape(
        "masks doesn't contain any mask. No mask was drawn")

    with pytest.warns(UserWarning, match=expected_message):
        drawn = utils.draw_segmentation_masks(image, empty_masks)
        # The returned image must be element-wise equal to the input.
        assert (drawn == image).all()
Example #2
0
    def test_draw_segmentation_masks_no_colors(self):
        """Draw masks with ``colors=None`` and compare against a stored PNG.

        When the reference PNG does not exist yet, it is first written from
        the current result. ``masks`` is not defined in this method; it is
        presumably a module-level fixture (confirm against the full file).
        """
        img = torch.full((3, 20, 20), 255, dtype=torch.uint8)
        result = utils.draw_segmentation_masks(img, masks, colors=None)

        test_dir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(test_dir, "assets", "fakedata",
                            "draw_segm_masks_no_colors_util.png")

        if not os.path.exists(path):
            # Move the leading axis last before handing the array to PIL.
            pil_ready = result.permute(1, 2, 0).contiguous().numpy()
            Image.fromarray(pil_ready).save(path)

        # Undo the axis move so the reference lines up with ``result``.
        stored = np.array(Image.open(path))
        expected = torch.as_tensor(stored).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))
Example #3
0
def test_draw_segmentation_masks(colors, alpha):
    """Check that each mask is painted with its own color and nothing else changes."""
    num_masks, h, w = 2, 100, 100
    dtype = torch.uint8
    img = torch.randint(0, 256, size=(3, h, w), dtype=dtype)
    masks = torch.randint(0, 2, (num_masks, h, w), dtype=torch.bool)

    # Clear every pixel claimed by both masks. When masks overlap, the last
    # mask's color currently wins; removing the overlap keeps the per-mask
    # checks below simple.
    masks[:, masks[0] & masks[1]] = False

    out = utils.draw_segmentation_masks(img, masks, colors=colors, alpha=alpha)
    assert out.dtype == dtype
    assert out is not img

    # Pixels outside every mask must be left exactly as they were.
    untouched = ~(masks[0] | masks[1])
    assert_equal(img[:, untouched], out[:, untouched])

    # Normalize ``colors`` to a list so it can be paired with the masks.
    if colors is None:
        colors = utils._generate_color_palette(num_masks)
    elif isinstance(colors, (str, tuple)):
        colors = [colors]

    for mask, color in zip(masks, colors):
        rgb = ImageColor.getrgb(color) if isinstance(color, str) else color
        rgb = torch.tensor(rgb, dtype=dtype)
        source_pixels = img[:, mask]
        drawn_pixels = out[:, mask]

        # The two extreme alphas have exact expected values.
        if alpha == 1:
            assert (drawn_pixels == rgb[:, None]).all()
        elif alpha == 0:
            assert (drawn_pixels == source_pixels).all()

        # In every case the result must match the linear blend of image and
        # color, within one intensity level.
        blended = (source_pixels * (1 - alpha) + rgb[:, None] * alpha).to(dtype)
        torch.testing.assert_close(drawn_pixels, blended, rtol=0.0, atol=1.0)
Example #4
0
    def test_draw_segmentation_masks_colors(self):
        """Draw masks with explicit colors and compare against a stored PNG.

        Also verifies that neither the image nor the masks were modified by
        the call. When the reference PNG does not exist yet, it is first
        written from the current result. ``masks`` is not defined in this
        method; it is presumably a module-level fixture (confirm against the
        full file).
        """
        img = torch.full((3, 5, 5), 255, dtype=torch.uint8)
        # Snapshots used by the not-in-place checks at the end.
        img_before, masks_before = img.clone(), masks.clone()

        result = utils.draw_segmentation_masks(
            img, masks, colors=["#FF00FF", (0, 255, 0), "red"])

        test_dir = os.path.dirname(os.path.abspath(__file__))
        path = os.path.join(test_dir, "assets", "fakedata",
                            "draw_segm_masks_colors_util.png")

        if not os.path.exists(path):
            # Move the leading axis last before handing the array to PIL.
            pil_ready = result.permute(1, 2, 0).contiguous().numpy()
            Image.fromarray(pil_ready).save(path)

        stored = np.array(Image.open(path))
        expected = torch.as_tensor(stored).permute(2, 0, 1)
        self.assertTrue(torch.equal(result, expected))

        # Neither input may have been changed by the call.
        for current, snapshot in ((img, img_before), (masks, masks_before)):
            self.assertTrue(torch.all(torch.eq(current, snapshot)).item())
Example #5
0
    def visualize_batch(self):
        """Fetch one shuffled batch (batch size 4) and pass it to ``self.show``.

        Outside ``"test"`` mode the batch carries masks, which are overlaid on
        the uint8-converted images. In ``"test"`` mode the batch has no masks
        and the images are shown as loaded.
        """
        loader = DataLoader(self, shuffle=True, batch_size=4)
        batch = next(iter(loader))

        is_test = self.mode == "test"
        if is_test:
            imgs, fnames = batch
        else:
            imgs, masks, fnames = batch

        # Only the overlay branch below consumes the converted batch.
        batch_inputs = F.convert_image_dtype(imgs, dtype=torch.uint8)

        if is_test:
            list_imgs = list(imgs)
        else:
            batch_outputs = F.convert_image_dtype(masks, dtype=torch.bool)
            list_imgs = [
                draw_segmentation_masks(image,
                                        masks=mask,
                                        alpha=0.6,
                                        colors=(102, 255, 178))
                for image, mask in zip(batch_inputs, batch_outputs)
            ]

        self.show(list_imgs, fnames)
Example #6
0
# can read it as the following query: "For which pixels is 'dog' the most likely
# class?"
#
# .. note::
#   While we're using the ``normalized_masks`` here, we would have
#   gotten the same result by using the non-normalized scores of the model
#   directly (as the softmax operation preserves the order).
#
# Now that we have boolean masks, we can use them with
# :func:`~torchvision.utils.draw_segmentation_masks` to plot them on top of the
# original images:

from torchvision.utils import draw_segmentation_masks

# Overlay every image of the batch with its boolean dog mask. No ``colors``
# argument is passed, so the function's default coloring applies.
# ``show`` is presumably a plotting helper defined earlier in this tutorial.
dogs_with_masks = [
    draw_segmentation_masks(img, masks=mask, alpha=0.7)
    for img, mask in zip(batch_int, boolean_dog_masks)
]
show(dogs_with_masks)

#####################################
# We can plot more than one mask per image! Remember that the model returned as
# many masks as there are classes. Let's ask the same query as above, but this
# time for *all* classes, not just the dog class: "For each pixel and each class
# C, is class C the most likely class?"
#
# This one is a bit more involved, so we'll first show how to do it with a
# single image, and then we'll generalize to the batch

# The class count is read from dim 1 of ``normalized_masks``; index 0 along
# dim 0 selects the masks of the first image (assumes a batch-first layout —
# TODO confirm).
num_classes = normalized_masks.shape[1]
dog1_masks = normalized_masks[0]
Example #7
0
def test_draw_segmentation_masks_errors():
    """Every kind of invalid input must raise the matching exception and message."""
    h, w = 10, 10
    draw = utils.draw_segmentation_masks

    masks = torch.randint(0, 2, size=(h, w), dtype=torch.bool)
    img = torch.randint(0, 256, size=(3, h, w), dtype=torch.uint8)

    # Invalid images: wrong type, wrong dtype, batched, wrong channel count.
    with pytest.raises(TypeError, match="The image must be a tensor"):
        draw(image="Not A Tensor Image", masks=masks)

    img_bad_dtype = torch.randint(0, 256, size=(3, h, w), dtype=torch.int64)
    with pytest.raises(ValueError, match="The image dtype must be"):
        draw(image=img_bad_dtype, masks=masks)

    batch = torch.randint(0, 256, size=(10, 3, h, w), dtype=torch.uint8)
    with pytest.raises(ValueError,
                       match="Pass individual images, not batches"):
        draw(image=batch, masks=masks)

    one_channel = torch.randint(0, 256, size=(1, h, w), dtype=torch.uint8)
    with pytest.raises(ValueError, match="Pass an RGB image"):
        draw(image=one_channel, masks=masks)

    # Invalid masks: wrong dtype, too many dims, spatial size mismatch.
    masks_bad_dtype = torch.randint(0, 2, size=(h, w), dtype=torch.float)
    with pytest.raises(ValueError, match="The masks must be of dtype bool"):
        draw(image=img, masks=masks_bad_dtype)

    masks_bad_shape = torch.randint(0, 2, size=(3, 2, h, w), dtype=torch.bool)
    with pytest.raises(ValueError, match="masks must be of shape"):
        draw(image=img, masks=masks_bad_shape)

    masks_bad_shape = torch.randint(0, 2, size=(h + 4, w), dtype=torch.bool)
    with pytest.raises(ValueError,
                       match="must have the same height and width"):
        draw(image=img, masks=masks_bad_shape)

    # Invalid colors: too few, wrong container type, tuple of color names.
    with pytest.raises(ValueError, match="There are more masks"):
        draw(image=img, masks=masks, colors=[])

    bad_colors = np.array(["red", "blue"])  # should be a list
    with pytest.raises(
            ValueError,
            match="colors must be a tuple or a string, or a list thereof"):
        draw(image=img, masks=masks, colors=bad_colors)

    bad_colors = ("red", "blue")  # should be a list
    with pytest.raises(
            ValueError,
            match="It seems that you passed a tuple of colors instead of"):
        draw(image=img, masks=masks, colors=bad_colors)
# transparency of masks.
#
# Here is a demo with torchvision's FCN Resnet-50, loaded with
# :func:`~torchvision.models.segmentation.fcn_resnet50`.
# You can also try using
# DeepLabv3 (:func:`~torchvision.models.segmentation.deeplabv3_resnet50`)
# or lraspp mobilenet models
# (:func:`~torchvision.models.segmentation.lraspp_mobilenet_v3_large`).
#
# Like :func:`~torchvision.utils.draw_bounding_boxes`,
# :func:`~torchvision.utils.draw_segmentation_masks` requires a single RGB image
# of dtype ``uint8``.

from torchvision.models.segmentation import fcn_resnet50
from torchvision.utils import draw_segmentation_masks

# Build the model with ``pretrained=True`` and switch it to evaluation mode.
model = fcn_resnet50(pretrained=True, progress=False)
model = model.eval()

# The model expects the batch to be normalized
batch = F.normalize(batch,
                    mean=(0.485, 0.456, 0.406),
                    std=(0.229, 0.224, 0.225))
outputs = model(batch)

# Pair each original image with its entry in ``outputs['out']`` and draw the
# masks on it.
# NOTE(review): ``outputs['out']`` is passed as ``masks`` as-is — confirm that
# the installed torchvision version accepts the model's raw output here.
dogs_with_masks = [
    draw_segmentation_masks(dog_int, masks=masks, alpha=0.6)
    for dog_int, masks in zip((dog1_int, dog2_int), outputs['out'])
]
show(dogs_with_masks)
# For each instance, the boolean tensors represent if the particular pixel
# belongs to the segmentation mask of the image.

# Print the size of the mask tensor first, then its contents.
print(masks.size())
print(masks)

####################################
# Let us visualize an image and plot its corresponding segmentation masks.
# We will use :func:`~torchvision.utils.draw_segmentation_masks` to draw the segmentation masks.

from torchvision.utils import draw_segmentation_masks

# One output image per mask: each iteration draws a single mask over ``img``
# in blue and collects the returned image.
drawn_masks = []
for mask in masks:
    drawn_masks.append(
        draw_segmentation_masks(img, mask, alpha=0.8, colors="blue"))

# ``show`` is presumably a plotting helper defined earlier in this tutorial.
show(drawn_masks)

####################################
# To convert the boolean masks into bounding boxes,
# we will use :func:`~torchvision.ops.masks_to_boxes` from the ``torchvision.ops`` module.
# It returns the boxes in ``(xmin, ymin, xmax, ymax)`` format.

from torchvision.ops import masks_to_boxes

# Compute the boxes for all masks in one call, then print the result's size
# and values.
boxes = masks_to_boxes(masks)
print(boxes.size())
print(boxes)

####################################