Esempio n. 1
0
 def test_five_crop(self):
     """Compare tensor ``F_t.five_crop`` with PIL ``F.five_crop`` on uint8 data.

     Also checks that the input tensor is left unmodified and that the
     TorchScript-compiled ``F_t.five_crop`` returns the same crops as the
     eager version.
     """
     script_five_crop = torch.jit.script(F_t.five_crop)
     # randint's upper bound is exclusive: 256 covers the full uint8 range
     # (an upper bound of 255 never produces the value 255).
     img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
     img_tensor_clone = img_tensor.clone()
     cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
     cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor),
                                     [10, 10])

     def pil_to_uint8(pil_img):
         # ToTensor scales pixels into [0, 1]; undo that so the result can be
         # compared exactly with the uint8 tensor crops.
         return (transforms.ToTensor()(pil_img) * 255).to(torch.uint8)

     # Tensor crop i is compared with PIL crop pil_order[i].
     # NOTE(review): positions 1 and 2 are swapped; presumably F_t.five_crop
     # and F.five_crop disagree on the order of two corner crops at this
     # version -- confirm whether this is intended.
     pil_order = [0, 2, 1, 3, 4]
     self.assertEqual(len(cropped_tensor), len(pil_order))
     self.assertEqual(len(cropped_pil_image), len(pil_order))
     for tensor_idx, pil_idx in enumerate(pil_order):
         self.assertTrue(
             torch.equal(cropped_tensor[tensor_idx],
                         pil_to_uint8(cropped_pil_image[pil_idx])),
             msg="tensor crop {} != PIL crop {}".format(tensor_idx, pil_idx))
     # five_crop must not modify its input in place.
     self.assertTrue(torch.equal(img_tensor, img_tensor_clone))
     # scriptable function test
     cropped_script = script_five_crop(img_tensor, [10, 10])
     # zip() stops at the shorter sequence, so compare lengths explicitly.
     self.assertEqual(len(cropped_script), len(cropped_tensor))
     for cropped_script_img, cropped_tensor_img in zip(
             cropped_script, cropped_tensor):
         self.assertTrue(torch.equal(cropped_script_img,
                                     cropped_tensor_img))
Esempio n. 2
0
 def test_five_crop(self):
     """Compare tensor ``F_t.five_crop`` with PIL ``F.five_crop`` on uint8 data.

     Also checks that the input tensor is left unmodified.
     """
     # randint's upper bound is exclusive: 256 covers the full uint8 range
     # (an upper bound of 255 never produces the value 255).
     img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
     img_tensor_clone = img_tensor.clone()
     cropped_tensor = F_t.five_crop(img_tensor, [10, 10])
     cropped_pil_image = F.five_crop(transforms.ToPILImage()(img_tensor),
                                     [10, 10])

     def pil_to_uint8(pil_img):
         # ToTensor scales pixels into [0, 1]; undo that so the result can be
         # compared exactly with the uint8 tensor crops.
         return (transforms.ToTensor()(pil_img) * 255).to(torch.uint8)

     # Tensor crop i is compared with PIL crop pil_order[i].
     # NOTE(review): positions 1 and 2 are swapped; presumably F_t.five_crop
     # and F.five_crop disagree on the order of two corner crops at this
     # version -- confirm whether this is intended.
     pil_order = [0, 2, 1, 3, 4]
     self.assertEqual(len(cropped_tensor), len(pil_order))
     self.assertEqual(len(cropped_pil_image), len(pil_order))
     for tensor_idx, pil_idx in enumerate(pil_order):
         self.assertTrue(
             torch.equal(cropped_tensor[tensor_idx],
                         pil_to_uint8(cropped_pil_image[pil_idx])),
             msg="tensor crop {} != PIL crop {}".format(tensor_idx, pil_idx))
     # five_crop must not modify its input in place.
     self.assertTrue(torch.equal(img_tensor, img_tensor_clone))