def multi_bernoulli_activated_equality(xz, yz, az):
    """
    Bitwise log probabilities of the event "x == y OR a == 1" and of its
    complement, for three independent Multi-Bernoulli variables.

    Parameters
    ----------
    xz : torch.tensor
        logits (pre-sigmoid) of the first Multi-Bernoulli.
    yz : torch.tensor
        logits (pre-sigmoid) of the second Multi-Bernoulli.
    az : torch.tensor
        logits (pre-sigmoid) of the activation Multi-Bernoulli; wherever its
        bit is one, the event is satisfied regardless of x and y.

    Returns
    -------
    log_p0 : torch.tensor
        bitwise log probability that the x and y bits differ while the
        activation bit is zero.
    log_p1 : torch.tensor
        bitwise log probability that the x and y bits are equal or the
        activation bit is one.

    Notes
    -----
    xz, yz and az do not need to share a shape as long as they broadcast
    against each other (all operations are elementwise).
    """
    # logsigmoid(z) = log P(bit = 1); logsigmoid(-z) = log P(bit = 0).
    log_x1, log_y1, log_a1 = logsigmoid(xz), logsigmoid(yz), logsigmoid(az)
    log_x0, log_y0, log_a0 = logsigmoid(-xz), logsigmoid(-yz), logsigmoid(-az)

    # Activation off and bits differ: (x, y) = (1, 0) or (0, 1).
    log_off_differ = torch_logsumexp(
        log_a0 + log_x1 + log_y0,
        log_a0 + log_x0 + log_y1,
    )
    # Activation on (any x, y), or activation off and (x, y) = (1, 1) or (0, 0).
    log_on_or_equal = torch_logsumexp(
        log_a1,
        log_a0 + log_x1 + log_y1,
        log_a0 + log_x0 + log_y0,
    )
    return log_off_differ, log_on_or_equal
def multi_bernoulli_equality(xz, yz):
    """
    Bitwise log probabilities that two independent Multi-Bernoulli variables
    differ and agree.

    Parameters
    ----------
    xz : torch.tensor
        logits (pre-sigmoid) of the first Multi-Bernoulli.
    yz : torch.tensor
        logits (pre-sigmoid) of the second Multi-Bernoulli.

    Returns
    -------
    log_p0 : torch.tensor
        bitwise log probability that the two bits differ.
    log_p1 : torch.tensor
        bitwise log probability that the two bits are equal.

    Notes
    -----
    xz and yz do not need to share a shape as long as they broadcast
    against each other (all operations are elementwise).
    """
    # logsigmoid(z) = log P(bit = 1); logsigmoid(-z) = log P(bit = 0).
    log_x1, log_y1 = logsigmoid(xz), logsigmoid(yz)
    log_x0, log_y0 = logsigmoid(-xz), logsigmoid(-yz)

    # Bits differ: (x, y) = (1, 0) or (0, 1).
    log_differ = torch_logsumexp(log_x1 + log_y0, log_x0 + log_y1)
    # Bits agree: (x, y) = (1, 1) or (0, 0).
    log_equal = torch_logsumexp(log_x1 + log_y1, log_x0 + log_y0)
    return log_differ, log_equal
# Example #3
# 0
def log_hamming_binomial(log_p10, log_p11, log_p20, log_p21):
    """
    Computes the log probability of each Hamming Binomial event, i.e. of each
    possible Hamming distance between two independent Multi-Bernoulli random
    vectors parameterized by p1 and p2.
    
    Parameters
    ----------
    log_p10 : torch.tensor (dtype=torch.float)
        The log probability of each bit of the first random vector to be zero.
        The Hamming Binomial is considered to be on the last dim. shape=(a1,a2,a3,...,am,n)
        where n is the number of bits of each vector. a1,a2,a3,...,am are arbitrary but
        should be broadcastable with the other inputs.
    
    log_p11 : torch.tensor (dtype=torch.float)
        The log probability of each bit of the first random vector to be one.
        The Hamming Binomial is considered to be on the last dim. shape=(a1,a2,a3,...,am,n)
        where n is the number of bits of each vector. a1,a2,a3,...,am are arbitrary but
        should be broadcastable with the other inputs.
        
    log_p20 : torch.tensor (dtype=torch.float)
        The log probability of each bit of the second random vector to be zero.
        The Hamming Binomial is considered to be on the last dim. shape=(a1,a2,a3,...,am,n)
        where n is the number of bits of each vector. a1,a2,a3,...,am are arbitrary but
        should be broadcastable with the other inputs.
    
    log_p21 : torch.tensor (dtype=torch.float)
        The log probability of each bit of the second random vector to be one.
        The Hamming Binomial is considered to be on the last dim. shape=(a1,a2,a3,...,am,n)
        where n is the number of bits of each vector. a1,a2,a3,...,am are arbitrary but
        should be broadcastable with the other inputs.
        
    Returns
    -------
    log_hb : torch.tensor (dtype=torch.float)
        The log probability of each Hamming Binomial event. shape=(a1,a2,a3,...,am,n+1).
        log_hb[i1,i2,i3,...,im,k] is the log probability that the Hamming distance between
        the two random Multi-Bernoulli parameterized by log_p11.exp()[i1,i2,i3,...,im] and
        log_p21.exp()[i1,i2,i3,...,im] is k.
        
    Notes
    -----
    see log_poisson_binomial's notes.
    """
    # Per-bit log probability that the two bits agree (both one or both zero).
    log_q0 = torch_logsumexp(log_p11 + log_p21, log_p10 + log_p20)
    # Per-bit log probability that the two bits differ (one/zero or zero/one).
    log_q1 = torch_logsumexp(log_p11 + log_p20, log_p10 + log_p21)
    # The Hamming distance is the number of differing bits; log_poisson_binomial
    # is expected to return its distribution from (log_q0, log_q1).
    return log_poisson_binomial(log_q0, log_q1)
# Example #4
# 0
def _log_poisson_binomial(log_q0, log_q1):
    bs, n = log_q0.shape
    dtype, device = log_q0.dtype, log_q0.device
    ninf = torch.tensor(-np.inf, dtype=dtype, device=device).expand(1, bs)
    z = torch.zeros((1, bs), dtype=dtype, device=device)
    for n,(a,b) in enumerate(zip(log_q0.permute(1,0), log_q1.permute(1,0))):
        w0 = torch.cat([a + z, ninf], dim=0)
        w1 = torch.cat([ninf, b + z], dim=0)
        z = torch_logsumexp(w0, w1)
    return z.permute(1,0)