def test_center_crop(self):
    """Check that tensor ``center_crop`` agrees with the PIL code path.

    A random single-channel 32x32 ``uint8`` image is cropped to 10x10 with
    ``F_t.center_crop``.  The same image is converted to a PIL image,
    cropped with ``F.center_crop``, converted back with ``ToTensor``,
    multiplied by 255 and cast to ``uint8``.  Both results must be equal
    element for element.
    """
    # ``torch.randint`` excludes ``high``, so 256 is required for the
    # sample to be able to contain the maximum uint8 value 255.
    img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)

    cropped_tensor = F_t.center_crop(img_tensor, [10, 10])

    # Reference result computed through the PIL implementation.
    cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
    cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)

    self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
def test_center_crop(self):
    """Check tensor ``center_crop`` against PIL, in eager and scripted form.

    A random single-channel 32x32 ``uint8`` image is cropped to 10x10 and
    three properties are asserted:

    * ``F_t.center_crop`` matches ``F.center_crop`` applied to the PIL
      version of the image (converted back with ``ToTensor``, multiplied
      by 255 and cast to ``uint8``);
    * the input tensor still equals a clone taken before the crop;
    * the ``torch.jit.script`` compiled function returns the same tensor
      as the eager call.
    """
    script_center_crop = torch.jit.script(F_t.center_crop)

    # ``torch.randint`` excludes ``high``, so 256 is required for the
    # sample to be able to contain the maximum uint8 value 255.
    img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
    # Snapshot used below to detect in-place modification of the input.
    img_tensor_clone = img_tensor.clone()

    cropped_tensor = F_t.center_crop(img_tensor, [10, 10])

    # Reference result computed through the PIL implementation.
    cropped_pil_image = F.center_crop(transforms.ToPILImage()(img_tensor), [10, 10])
    cropped_pil_tensor = (transforms.ToTensor()(cropped_pil_image) * 255).to(torch.uint8)

    self.assertTrue(torch.equal(cropped_tensor, cropped_pil_tensor))
    self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

    # scriptable function test
    cropped_script = script_center_crop(img_tensor, [10, 10])
    self.assertTrue(torch.equal(cropped_script, cropped_tensor))