コード例 #1
0
 def test_center_crop(self):
     """Check that ``F_t.center_crop`` on a tensor agrees with ``F.center_crop`` on a PIL image.

     A random single-channel 32x32 ``uint8`` image is cropped to 10x10 through
     both code paths; the two results must be element-wise equal.
     """
     # torch.randint's upper bound is exclusive, so 256 is required for the
     # maximum uint8 value (255) to be part of the sampled range.
     img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
     cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
     cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor),
                                       [10, 10])
     # Convert the PIL crop back into a uint8 tensor so that it can be
     # compared directly against the tensor crop.
     cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) *
                           255).to(torch.uint8)
     self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor),
                     msg="tensor center_crop differs from PIL center_crop")
コード例 #2
0
 def test_center_crop(self):
     """Check ``F_t.center_crop`` against the PIL path, input immutability and scripting.

     A random single-channel 32x32 ``uint8`` image is cropped to 10x10 and the
     test verifies that:

     * the tensor crop equals the crop obtained through ``F.center_crop`` on
       the corresponding PIL image,
     * the input tensor is left unmodified by the crop,
     * the ``torch.jit.script`` compiled function returns the same result as
       the eager call.
     """
     script_center_crop = torch.jit.script(F_t.center_crop)
     # torch.randint's upper bound is exclusive, so 256 is required for the
     # maximum uint8 value (255) to be part of the sampled range.
     img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
     # Snapshot of the input taken before cropping, used to detect in-place changes.
     img_tensor_clone = img_tensor.clone()
     cropped_tensor = F_t.center_crop(img_tensor, [10, 10])
     cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
     # Convert the PIL crop back into a uint8 tensor so that it can be
     # compared directly against the tensor crop.
     cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)
     self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor),
                     msg="tensor center_crop differs from PIL center_crop")
     self.assertTrue(torch.equal(img_tensor, img_tensor_clone),
                     msg="center_crop modified its input tensor")
     # scriptable function test
     cropped_script = script_center_crop(img_tensor, [10, 10])
     self.assertTrue(torch.equal(cropped_script, cropped_tensor),
                     msg="scripted center_crop differs from eager center_crop")