def test_ten_crop(self):
    """Check tensor ``ten_crop`` against the PIL version and against TorchScript.

    A random single-channel uint8 image is cropped to 10x10 with both
    ``F_t.ten_crop`` (tensor) and ``F.ten_crop`` (PIL). The test checks that:

    * every tensor crop equals the corresponding PIL crop;
    * the input tensor is left unmodified;
    * the scripted ``F_t.ten_crop`` returns the same crops as the eager one.
    """
    script_ten_crop = torch.jit.script(F_t.ten_crop)
    # torch.randint's upper bound is exclusive; 256 lets 255 appear as well.
    img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
    img_tensor_clone = img_tensor.clone()

    cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
    cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])

    # Maps each tensor crop index to the PIL crop index it must equal.
    # The two implementations order crops 1 and 2 differently, and also
    # crops 6 and 7 (presumably the flipped copies of 1 and 2).
    tensor_to_pil = [0, 2, 1, 3, 4, 5, 7, 6, 8, 9]
    self.assertEqual(len(cropped_tensor), len(tensor_to_pil))
    self.assertEqual(len(cropped_pil_image), len(tensor_to_pil))

    to_tensor = transforms.ToTensor()
    for tensor_idx, pil_idx in enumerate(tensor_to_pil):
        # ToTensor rescales to [0, 1]; scale back up to compare as uint8.
        expected = (to_tensor(cropped_pil_image[pil_idx]) * 255).to(torch.uint8)
        self.assertTrue(
            torch.equal(cropped_tensor[tensor_idx], expected),
            msg="tensor crop {} != PIL crop {}".format(tensor_idx, pil_idx),
        )

    # The crop must not modify its input in place.
    self.assertTrue(torch.equal(img_tensor, img_tensor_clone))

    # scriptable function test
    cropped_script = script_ten_crop(img_tensor, [10, 10])
    # zip() stops at the shorter sequence, so check the counts match first.
    self.assertEqual(len(cropped_script), len(cropped_tensor))
    for cropped_script_img, cropped_tensor_img in zip(cropped_script, cropped_tensor):
        self.assertTrue(torch.equal(cropped_script_img, cropped_tensor_img))
def test_ten_crop(self):
    """Check tensor ``ten_crop`` against the PIL implementation.

    A random single-channel uint8 image is cropped to 10x10 with both
    ``F_t.ten_crop`` (tensor) and ``F.ten_crop`` (PIL). The test checks that
    every tensor crop equals the corresponding PIL crop and that the input
    tensor is left unmodified.
    """
    # NOTE(review): another ``test_ten_crop`` definition appears next to this
    # one. If both are in the same class, the one defined later silently
    # replaces the other, so only one of them runs. Keep a single definition.

    # torch.randint's upper bound is exclusive; 256 lets 255 appear as well.
    img_tensor = torch.randint(0, 256, (1, 32, 32), dtype=torch.uint8)
    img_tensor_clone = img_tensor.clone()

    cropped_tensor = F_t.ten_crop(img_tensor, [10, 10])
    cropped_pil_image = F.ten_crop(transforms.ToPILImage()(img_tensor), [10, 10])

    # Maps each tensor crop index to the PIL crop index it must equal.
    # The two implementations order crops 1 and 2 differently, and also
    # crops 6 and 7 (presumably the flipped copies of 1 and 2).
    tensor_to_pil = [0, 2, 1, 3, 4, 5, 7, 6, 8, 9]
    self.assertEqual(len(cropped_tensor), len(tensor_to_pil))
    self.assertEqual(len(cropped_pil_image), len(tensor_to_pil))

    to_tensor = transforms.ToTensor()
    for tensor_idx, pil_idx in enumerate(tensor_to_pil):
        # ToTensor rescales to [0, 1]; scale back up to compare as uint8.
        expected = (to_tensor(cropped_pil_image[pil_idx]) * 255).to(torch.uint8)
        self.assertTrue(
            torch.equal(cropped_tensor[tensor_idx], expected),
            msg="tensor crop {} != PIL crop {}".format(tensor_idx, pil_idx),
        )

    # The crop must not modify its input in place.
    self.assertTrue(torch.equal(img_tensor, img_tensor_clone))