def test_shape_batch(self, device):
    """A batched call yields one ``(n_bins, n_bins)`` joint histogram per sample."""
    batch_size, n_samples, n_bins = 8, 32, 128
    x = torch.ones(batch_size, n_samples, device=device)
    y = torch.ones(batch_size, n_samples, device=device)
    bin_grid = torch.linspace(0, 255, n_bins).to(device)
    bw = torch.tensor(0.9, device=device)  # 0-d float32, same as the numpy round-trip
    joint = histogram2d(x, y, bin_grid, bw)
    assert joint.shape == (batch_size, n_bins, n_bins)
def test_uniform_dist(self, device):
    """Identical evenly spaced inputs must give a joint PDF of ``eye(10) * 0.1``.

    Both inputs and the bins are the same 10-point ``linspace`` over
    [0, 255], so each sample coincides with one bin value on each axis.
    The expected result is a (1, 10, 10) diagonal matrix with 0.1 on the
    diagonal (assumes ``histogram2d``'s kernel weight is negligible at one
    bin spacing for this bandwidth -- confirm against its implementation).
    """
    n_bins = 10
    input1 = torch.linspace(0, 255, n_bins, device=device).unsqueeze(0)
    input2 = torch.linspace(0, 255, n_bins, device=device).unsqueeze(0)
    bins = torch.linspace(0, 255, n_bins, device=device)
    # Every tensor passed to histogram2d must live on the same device.
    bandwidth = torch.tensor(2 * 0.4**2, device=device)
    joint_pdf = histogram2d(input1, input2, bins, bandwidth)
    ans = torch.eye(n_bins, device=device).unsqueeze(0) * 0.1
    # Guard against silent broadcasting between mismatched shapes.
    assert joint_pdf.shape == ans.shape
    # Use the absolute error: a signed sum lets over-estimates (negative
    # differences) pass and lets opposite-sign errors cancel out.
    assert (ans - joint_pdf).abs().sum().item() < 1e-6