def test_uniform_dist(self, device):
    """Evenly spaced samples placed on the bin centers should yield a uniform pdf.

    The 10 input samples are exactly the 10 bin centers (both come from
    ``torch.linspace(0, 255, 10)``). The bandwidth (2 * 0.4**2 = 0.32) is much
    smaller than the spacing between centers (~28.3). Presumably ``histogram``
    applies a kernel whose width is governed by ``bandwidth`` (TODO confirm). If
    so, each sample lands in exactly one bin and every bin should hold 1/10 of
    the mass.
    """
    input1 = torch.linspace(0, 255, 10).unsqueeze(0).to(device)
    bins = torch.linspace(0, 255, 10).to(device)
    # Create the bandwidth on the same device as the samples and bins.
    bandwidth = torch.tensor(2 * 0.4**2, device=device)
    pdf = histogram(input1, bins, bandwidth)
    expected = torch.ones((1, 10)) * 0.1
    # Check the shape explicitly, because the subtraction below would broadcast
    # mismatched shapes silently.
    assert pdf.shape == expected.shape
    # Compare element-wise magnitudes. A signed sum of differences lets errors
    # cancel, and any pdf that overshoots the target passes it trivially.
    assert (expected - pdf.cpu()).abs().max() < 1e-6
def test_shape_batch(self, device):
    """A batched input produces one histogram row of ``num_bins`` entries per batch item."""
    batch_size, num_samples, num_bins = 8, 32, 128
    samples = torch.ones(batch_size, num_samples, device=device)
    centers = torch.linspace(0, 255, num_bins).to(device)
    sigma = torch.Tensor(np.array(0.9)).to(device)
    result = histogram(samples, centers, sigma)
    assert result.shape == (batch_size, num_bins)