def test_normalize(self):
    """Check raw counts, then check that ``normalize=True`` divides them by the pixel count.

    The 3-pixel uint8 image ``[0, 255, 255]`` over the full dtype range
    should give 1 count in the first bin and 2 in the last. After
    normalisation these become 1/3 and 2/3.
    """
    image = torch.tensor([0, 255, 255], dtype=torch.uint8)

    hist, _ = histogram(image, source_range='dtype', normalize=False)
    expected = torch.zeros(256, dtype=torch.long)
    expected[0] = 1
    expected[-1] = 2
    self.assertTrue(torch.equal(hist, expected))

    hist, _ = histogram(image, source_range='dtype', normalize=True)
    expected = expected.float().div(3.0)
    # Normalised values are floating point: compare with a tolerance instead
    # of bit-exact equality, and cast to float32 so allclose sees matching
    # dtypes even if histogram returns float64.
    self.assertTrue(torch.allclose(hist.float(), expected))
def test_negative_image(self):
    """An all-negative int8 image uses the image's own range (-100 .. -1) for its bins.

    Only the first and last bins should be populated; every interior bin
    should be empty.
    """
    img = torch.tensor([-100, -1], dtype=torch.int8)
    counts, centers = histogram(img)

    expected_centers = torch.arange(-100, 0)
    self.assertTrue(torch.equal(centers, expected_centers))

    # One pixel sits at each end of the range.
    for edge in (0, -1):
        self.assertTrue(counts[edge] == 1)

    interior = counts[1:-1]
    self.assertTrue((interior == 0).all())
def test_peak_float_out_of_range_image(self):
    """A float image with 90 bins over its own range [10, 100] gets half-integer bin centers.

    The expected centers are 10.5, 11.5, ..., 99.5.
    """
    img = torch.tensor([10, 100], dtype=torch.float)
    _counts, centers = histogram(img, nbins=90)

    # For floats the centers sit midway between integer bin edges.
    expected_centers = torch.arange(10, 100).float() + 0.5
    self.assertTrue(torch.equal(centers, expected_centers))
def test_peak_float_out_of_range_dtype(self):
    """A float image histogrammed over its dtype range ignores the image's actual values.

    The 10 expected bin centers run from -0.9 to 0.9. That implies the float
    dtype range used by ``histogram`` is [-1, 1] (confirm against its
    implementation).
    """
    img = torch.tensor([10, 100], dtype=torch.float)
    _counts, centers = histogram(img, nbins=10, source_range='dtype')

    lowest = torch.min(centers)
    highest = torch.max(centers)
    self.assertTrue(torch.allclose(lowest, torch.tensor(-0.9)))
    self.assertTrue(torch.allclose(highest, torch.tensor(0.9)))
    self.assertEqual(len(centers), 10)
def test_peak_int_range_dtype(self):
    """An int8 image histogrammed over the full int8 range (-128 .. 127) gets 256 bins.

    Values 10 and 100 should each land in their own bin; the neighbouring
    value 101 should be empty.
    """
    img = torch.tensor([10, 100], dtype=torch.int8)
    counts, centers = histogram(img, source_range='dtype')

    self.assertTrue(torch.equal(centers, torch.arange(-128, 128)))

    zero_bin = 128  # index of value 0 when the bins start at -128
    for value, expected_count in ((10, 1), (100, 1), (101, 0)):
        self.assertEqual(counts[zero_bin + value], expected_count)

    self.assertEqual(counts.size(), torch.Size([256]))
def histogram_demo():
    """Show the ``camera`` sample image side by side with a plot of its histogram."""
    image = torch.tensor(data.camera())
    f, (ax1, ax2) = plt.subplots(1, 2)

    # The sample is a single-channel image; use a gray colormap so it is not
    # rendered in matplotlib's default false-colour map.
    ax1.imshow(image, cmap='gray')
    ax1.set_title("original image")
    ax1.get_xaxis().set_visible(False)
    ax1.get_yaxis().set_visible(False)

    hist, bin_centers = histogram(image)
    ax2.set_title("histogram of image")
    # ``histogram`` already returns per-bin counts. ``Axes.hist`` would
    # re-bin those counts (a histogram of frequencies), so plot the counts
    # against their bin centers directly instead.
    ax2.plot(bin_centers.numpy(), hist.numpy())
    f.show()
def test_flat_int_range_dtype(self):
    """An image holding every int8 value exactly once gives a flat 256-bin dtype histogram."""
    # Build the values in int64, where all of -128..127 are representable,
    # then cast. The previous ``linspace(-128, 128, 256)`` ended at 128.0,
    # which is outside int8 (an out-of-range float->int cast). Its
    # non-integer steps also truncated to a set of values that skipped some
    # integers and repeated others.
    image = torch.arange(-128, 128).to(torch.int8)
    hist, bin_centers = histogram(image, source_range='dtype')
    self.assertTrue(torch.equal(bin_centers, torch.arange(-128, 128)))
    self.assertEqual(hist.size(), torch.Size([256]))
    # Each value occurs once, so every bin must hold exactly one count.
    self.assertTrue((hist == 1).all())
def test_peak_uint_range_dtype(self):
    """Check peak placement for integer images under both range modes.

    First, the default (image) range on an int8 image: the bins span
    exactly 10 .. 100. Second, what the test name promises: a uint8 image
    histogrammed over its dtype range gets 256 bins centred on 0 .. 255,
    with the peaks at their own values.
    """
    # Default source range: bins follow the image's min/max.
    image = torch.tensor([10, 100], dtype=torch.int8)
    hist, bin_centers = histogram(image)
    self.assertEqual(len(hist), len(bin_centers))
    self.assertEqual(bin_centers[0], 10)
    self.assertEqual(bin_centers[-1], 100)

    # uint8 over its dtype range: bin index equals pixel value.
    image = torch.tensor([10, 100], dtype=torch.uint8)
    hist, bin_centers = histogram(image, source_range='dtype')
    self.assertTrue(torch.equal(bin_centers, torch.arange(0, 256)))
    self.assertEqual(hist[10], 1)
    self.assertEqual(hist[100], 1)
    self.assertEqual(hist[101], 0)
    self.assertEqual(hist.size(), torch.Size([256]))
def test_wrong_source_range(self):
    """An unrecognised ``source_range`` string must raise ValueError."""
    img = torch.tensor([-1, 100], dtype=torch.int8)
    self.assertRaises(ValueError, histogram, img, source_range='foobar')
def test_input_tensor(self):
    """A non-tensor input (a plain Python list) is rejected with TypeError."""
    not_a_tensor = [10, 100]
    self.assertRaises(TypeError, histogram, not_a_tensor)