Example #1
0
    def test_uniform_dist(self, device):
        """Check that evenly spaced samples yield a uniform histogram.

        Ten samples are placed exactly on the ten bin centres (both are built
        with ``torch.linspace(0, 255, 10)``), so every bin is expected to hold
        probability 0.1.
        """
        input1 = torch.linspace(0, 255, 10).unsqueeze(0).to(device)
        bins = torch.linspace(0, 255, 10).to(device)
        # Bandwidth 2 * 0.4**2 = 0.32 is tiny next to the ~28.3 bin spacing;
        # presumably this keeps each sample's mass within its own bin.
        # Moved to ``device`` like the other inputs for consistency.
        bandwidth = torch.Tensor(np.array(2 * 0.4**2)).to(device)

        pdf = histogram(input1, bins, bandwidth)
        ans = torch.ones((1, 10)) * 0.1

        # One row (batch of 1) with one entry per bin; checking the shape
        # prevents broadcasting from masking a mismatch in the comparison.
        assert pdf.shape == (1, 10)
        # Sum absolute errors: a signed sum lets errors cancel and lets any
        # pdf that is elementwise >= ``ans`` pass.
        assert (ans.cpu() - pdf.cpu()).abs().sum() < 1e-6
Example #2
0
 def test_shape_batch(self, device):
     """Histogram of a batched input returns one row of bin values per sample."""
     batch_size, num_bins = 8, 128
     samples = torch.ones(batch_size, 32, device=device)
     centers = torch.linspace(0, 255, num_bins).to(device)
     sigma = torch.Tensor(np.array(0.9)).to(device)
     result = histogram(samples, centers, sigma)
     # Expect (batch, num_bins) regardless of the per-sample length (32).
     assert result.shape == (batch_size, num_bins)