Beispiel #1
0
    def test_rectangle_erasing2(self, device):
        inputs = torch.ones(3, 3, 3, 3).to(device)
        rect_params = {
            "widths": torch.tensor([3, 2, 1]),
            "heights": torch.tensor([3, 2, 1]),
            "xs": torch.tensor([0, 1, 2]),
            "ys": torch.tensor([0, 1, 2]),
            "values": torch.tensor([0.0, 0.0, 0.0]),
        }
        expected = torch.tensor(
            [
                [
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                    [[0.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 0.0, 0.0]],
                ],
                [
                    [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                    [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                    [[1.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]],
                ],
                [
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 0.0]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 0.0]],
                    [[1.0, 1.0, 1.0], [1.0, 1.0, 1.0], [1.0, 1.0, 0.0]],
                ],
            ]
        ).to(device)

        assert_allclose(F.apply_erase_rectangles(inputs, rect_params), expected)
Beispiel #2
0
 def test_rectangle_erasing1(self, device):
     inputs = torch.ones(1, 1, 10, 10).to(device)
     rect_params = {
         "widths": torch.tensor([5]),
         "heights": torch.tensor([5]),
         "xs": torch.tensor([5]),
         "ys": torch.tensor([5]),
         "values": torch.tensor([0.])
     }
     expected = torch.tensor([[[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1., 1., 1., 1., 1., 1.],
                                [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
                                [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
                                [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
                                [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
                                [1., 1., 1., 1., 1., 0., 0., 0., 0.,
                                 0.]]]]).to(device)
     assert_allclose(F.apply_erase_rectangles(inputs, rect_params),
                     expected)