def test_ten_crop(self):
     """Check F_t.ten_crop against the PIL reference F.ten_crop and its scripted form.

     A random single-channel 32x32 uint8 tensor is cropped to 10x10 both
     directly (F_t.ten_crop) and through a PIL round trip (F.ten_crop).
     Each tensor crop must equal the matching PIL crop converted back to
     uint8. The input tensor must be left unmodified. The
     TorchScript-compiled F_t.ten_crop must reproduce the eager results
     exactly, crop for crop.
     """
     script_ten_crop = torch.jit.script(F_t.ten_crop)
     img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
     img_tensor_clone = img_tensor.clone()
     cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
     cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])

     # Position i of the tensor output corresponds to position pil_order[i]
     # of the PIL output. The two implementations return crops 1/2 and 6/7
     # in swapped order relative to each other.
     # NOTE(review): presumably the top-right/bottom-left crops of the
     # plain and flipped five-crop groups; confirm against F_t.five_crop.
     pil_order = (0, 2, 1, 3, 4, 5, 7, 6, 8, 9)
     # Pin the crop counts so extra crops from either side cannot go unnoticed.
     self.assertEqual(len(cropped_tensor), len(pil_order))
     self.assertEqual(len(cropped_pil_image), len(pil_order))
     for tensor_idx, pil_idx in enumerate(pil_order):
         # ToTensor yields floats scaled by 1/255; rescale to compare as uint8.
         expected = (transforms.ToTensor()(cropped_pil_image[pil_idx]) * 255).to(torch.uint8)
         self.assertTrue(torch.equal(cropped_tensor[tensor_idx], expected),
                         msg="tensor crop {} != PIL crop {}".format(tensor_idx, pil_idx))
     # ten_crop must not modify its input in place.
     self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

     # scriptable function test
     cropped_script = script_ten_crop(img_tensor, [10, 10])
     # zip() silently stops at the shorter sequence, so check lengths first.
     self.assertEqual(len(cropped_script), len(cropped_tensor))
     for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
         self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
Exemplo n.º 2
0
 def test_ten_crop(self):
     """Check F_t.ten_crop against the PIL reference implementation F.ten_crop.

     A random single-channel 32x32 uint8 tensor is cropped to 10x10 both
     directly (F_t.ten_crop) and through a PIL round trip (F.ten_crop).
     Each tensor crop must equal the matching PIL crop converted back to
     uint8, and the input tensor must be left unmodified.
     """
     img_tensor = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
     img_tensor_clone = img_tensor.clone()
     cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
     cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor),
                                    [10, 10])

     # Position i of the tensor output corresponds to position pil_order[i]
     # of the PIL output. The two implementations return crops 1/2 and 6/7
     # in swapped order relative to each other.
     # NOTE(review): presumably the top-right/bottom-left crops of the
     # plain and flipped five-crop groups; confirm against F_t.five_crop.
     pil_order = (0, 2, 1, 3, 4, 5, 7, 6, 8, 9)
     # Pin the crop counts so extra crops from either side cannot go unnoticed.
     self.assertEqual(len(cropped_tensor), len(pil_order))
     self.assertEqual(len(cropped_pil_image), len(pil_order))
     for tensor_idx, pil_idx in enumerate(pil_order):
         # ToTensor yields floats scaled by 1/255; rescale to compare as uint8.
         expected = (transforms.ToTensor()(cropped_pil_image[pil_idx]) *
                     255).to(torch.uint8)
         self.assertTrue(
             torch.equal(cropped_tensor[tensor_idx], expected),
             msg="tensor crop {} != PIL crop {}".format(tensor_idx, pil_idx))
     # ten_crop must not modify its input in place.
     self.assertTrue(torch.equal(img_tensor, img_tensor_clone))