def multi_bernoulli_activated_equality(xz, yz, az):
    """
    Bitwise log probability of the event "x equals y, or a is one".

    ``x``, ``y`` and ``a`` are independent Multi-Bernoulli variables given
    by their logits. For each bit, ``a`` acts as an activation: when its bit
    is one, the event holds whatever the bits of ``x`` and ``y`` are.

    Parameters
    ----------
    xz : torch.tensor
        The logits (before sigmoid) of the first Multi-Bernoulli ``x``.
    yz : torch.tensor
        The logits (before sigmoid) of the second Multi-Bernoulli ``y``.
    az : torch.tensor
        The logits (before sigmoid) of the third Multi-Bernoulli ``a``,
        which activates the equality.

    Returns
    -------
    log_p0 : torch.tensor
        Bitwise log probability that ``a`` is zero and ``x`` differs
        from ``y``.
    log_p1 : torch.tensor
        Bitwise log probability that ``a`` is one, or ``x`` equals ``y``.

    Notes
    -----
    xz, yz and az need not have the same shape, but they must be
    broadcastable against each other.
    """
    # log P(bit = 1) and log P(bit = 0) for each variable.
    log_x1, log_x0 = logsigmoid(xz), logsigmoid(-xz)
    log_y1, log_y0 = logsigmoid(yz), logsigmoid(-yz)
    log_a1, log_a0 = logsigmoid(az), logsigmoid(-az)

    # a is off and the bits of x and y disagree.
    mismatch_terms = (
        log_a0 + log_x1 + log_y0,
        log_a0 + log_x0 + log_y1,
    )
    # a is on (any x, y), or a is off and the bits of x and y agree.
    match_terms = (
        log_a1,
        log_a0 + log_x1 + log_y1,
        log_a0 + log_x0 + log_y0,
    )

    log_p0 = torch_logsumexp(*mismatch_terms)
    log_p1 = torch_logsumexp(*match_terms)
    return log_p0, log_p1
def multi_bernoulli_equality(xz, yz):
    """
    Bitwise log probability that two Multi-Bernoulli variables are equal.

    ``x`` and ``y`` are independent Multi-Bernoulli variables given by
    their logits. The probabilities are computed bit by bit.

    Parameters
    ----------
    xz : torch.tensor
        The logits (before sigmoid) of the first Multi-Bernoulli ``x``.
    yz : torch.tensor
        The logits (before sigmoid) of the second Multi-Bernoulli ``y``.

    Returns
    -------
    log_p0 : torch.tensor
        Bitwise log probability that ``x`` and ``y`` differ.
    log_p1 : torch.tensor
        Bitwise log probability that ``x`` and ``y`` are equal.

    Notes
    -----
    xz and yz need not have the same shape, but they must be
    broadcastable against each other.
    """
    # log P(bit = 1) and log P(bit = 0) for each variable.
    log_x1, log_x0 = logsigmoid(xz), logsigmoid(-xz)
    log_y1, log_y0 = logsigmoid(yz), logsigmoid(-yz)

    # Bits differ: (1, 0) or (0, 1).
    log_p0 = torch_logsumexp(log_x1 + log_y0, log_x0 + log_y1)
    # Bits agree: (1, 1) or (0, 0).
    log_p1 = torch_logsumexp(log_x1 + log_y1, log_x0 + log_y0)
    return log_p0, log_p1
def log_hamming_binomial(log_p10, log_p11, log_p20, log_p21):
    """
    Compute the log probability of every possible Hamming distance between
    two random Multi-Bernoulli vectors parameterized by p1 and p2.

    The Hamming distance is taken over the last dimension (the bits).

    Parameters
    ----------
    log_p10 : torch.tensor (dtype=torch.float)
        Log probability of each bit of the first random vector being zero.
        shape=(a1,a2,a3,...,am,n) where n is the number of bits per vector;
        a1,a2,a3,...,am are arbitrary but must be broadcastable with the
        other inputs.
    log_p11 : torch.tensor (dtype=torch.float)
        Log probability of each bit of the first random vector being one.
        Same shape convention as ``log_p10``.
    log_p20 : torch.tensor (dtype=torch.float)
        Log probability of each bit of the second random vector being zero.
        Same shape convention as ``log_p10``.
    log_p21 : torch.tensor (dtype=torch.float)
        Log probability of each bit of the second random vector being one.
        Same shape convention as ``log_p10``.

    Returns
    -------
    log_hb : torch.tensor (dtype=torch.float)
        shape=(a1,a2,a3,...,am,n+1). log_hb[i1,i2,i3,...,im,k] is the log
        probability that the Hamming distance between the two random
        Multi-Bernoulli vectors at index [i1,i2,i3,...,im] equals k.

    Notes
    -----
    See log_poisson_binomial's notes.
    """
    # Per-bit log probability that the two vectors agree (both 1 or both 0).
    log_same = torch_logsumexp(log_p11 + log_p21, log_p10 + log_p20)
    # Per-bit log probability that the two vectors disagree.
    log_diff = torch_logsumexp(log_p11 + log_p20, log_p10 + log_p21)
    # The Hamming distance counts the disagreeing bits, so it follows the
    # Poisson Binomial built from the per-bit (agree, disagree) probabilities.
    return log_poisson_binomial(log_same, log_diff)
def _log_poisson_binomial(log_q0, log_q1):
    """
    Log probability mass function of a batch of Poisson Binomial variables.

    Runs a dynamic program over the bits. After j bits have been processed,
    row k of the running table ``z`` holds the log probability that exactly
    k of those j bits are one.

    Parameters
    ----------
    log_q0 : torch.tensor
        Log probability of each bit being zero, shape=(bs, n).
    log_q1 : torch.tensor
        Log probability of each bit being one. Must be broadcastable with
        ``log_q0``; both are broadcast to a common 2-D (bs, n) shape.

    Returns
    -------
    torch.tensor
        shape=(bs, n+1). Entry [i, k] is the log probability that exactly
        k of the n bits of row i are one.
    """
    # Broadcast first so the batch size read below agrees with both inputs,
    # even when log_q0 is the operand that needs broadcasting.
    log_q0, log_q1 = torch.broadcast_tensors(log_q0, log_q1)
    bs, _ = log_q0.shape
    dtype, device = log_q0.dtype, log_q0.device
    # A row of log(0) used to pad the table when the count shifts by one.
    ninf = torch.tensor(-np.inf, dtype=dtype, device=device).expand(1, bs)
    # With no bits processed, a count of 0 has probability 1 (log 1 = 0).
    z = torch.zeros((1, bs), dtype=dtype, device=device)
    for a, b in zip(log_q0.permute(1, 0), log_q1.permute(1, 0)):
        # Current bit is zero: count k stays k.
        w0 = torch.cat([a + z, ninf], dim=0)
        # Current bit is one: count k-1 becomes k.
        w1 = torch.cat([ninf, b + z], dim=0)
        # Assumes torch_logsumexp(u, v) is the elementwise log(exp(u) + exp(v)).
        z = torch_logsumexp(w0, w1)
    return z.permute(1, 0)