def test_five_crop(self):
    """Compare tensor ``five_crop`` with the PIL version and with its TorchScript build.

    A random 1x32x32 uint8 tensor is cut into five 10x10 patches by
    ``F_t.five_crop`` and, after conversion to a PIL image, by
    ``F.five_crop``. Each tensor patch must equal the corresponding PIL
    patch converted back to uint8. The input tensor must be left
    untouched. The scripted ``F_t.five_crop`` must reproduce the eager
    result patch-for-patch.

    NOTE(review): another ``test_five_crop`` definition (without the
    TorchScript check) also appears in this file; if both live in the same
    class, the later definition shadows the earlier one — confirm which is
    intended.
    """
    scripted_five_crop = torch.jit.script(F_t.five_crop)

    source = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
    # Snapshot of the input, used below to check that cropping does not mutate it.
    source_backup = source.clone()

    eager_crops = F_t.five_crop(source, [10, 10])
    pil_crops = F.five_crop(transforms.ToPILImage()(source), [10, 10])

    # (tensor index, PIL index) pairs. Positions 1 and 2 are compared
    # crosswise. NOTE(review): presumably this reflects a differing crop
    # order between the tensor and PIL implementations; confirm against
    # F_t.five_crop.
    index_pairs = ((0, 0), (1, 2), (2, 1), (3, 3), (4, 4))
    for tensor_idx, pil_idx in index_pairs:
        # The `* 255` + uint8 cast presumably undoes ToTensor's [0, 1] scaling.
        expected = (transforms.ToTensor()(pil_crops[pil_idx]) * 255).to(torch.uint8)
        self.assertTrue(torch.equal(eager_crops[tensor_idx], expected))

    self.assertTrue(torch.equal(source, source_backup))

    # The TorchScript-compiled function must agree with eager execution.
    scripted_crops = scripted_five_crop(source, [10, 10])
    for from_script, from_eager in zip(scripted_crops, eager_crops):
        self.assertTrue(torch.equal(from_script, from_eager))
def test_five_crop(self):
    """Check that tensor ``five_crop`` matches the PIL implementation.

    A random 1x32x32 uint8 tensor is split into five 10x10 crops by
    ``F_t.five_crop`` and, via a PIL image, by ``F.five_crop``. Every tensor
    crop must equal its PIL counterpart converted back to uint8. The input
    tensor must not be modified by the call.
    """
    image = torch.randint(0, 255, (1, 32, 32), dtype=torch.uint8)
    # Kept to verify afterwards that five_crop leaves its input untouched.
    image_snapshot = image.clone()

    tensor_crops = F_t.five_crop(image, [10, 10])
    pil_crops = F.five_crop(transforms.ToPILImage()(image), [10, 10])

    # Tensor crop k is compared with PIL crop pil_order[k]; indices 1 and 2
    # are swapped. NOTE(review): presumably this reflects a differing crop
    # order between the two implementations; confirm against F_t.five_crop.
    pil_order = (0, 2, 1, 3, 4)
    for tensor_idx, pil_idx in enumerate(pil_order):
        # Scale back by 255 and cast: presumably undoes ToTensor's [0, 1] scaling.
        reference = (transforms.ToTensor()(pil_crops[pil_idx]) * 255).to(torch.uint8)
        self.assertTrue(torch.equal(tensor_crops[tensor_idx], reference))

    self.assertTrue(torch.equal(image, image_snapshot))