Example #1
import torch

# `batch_multi_choices` and `gather_expand` are project-local helpers;
# a sketch of plausible implementations follows this example.
def sample_M_K(probs_N_K_C, S=1000):
    probs_N_K_C = probs_N_K_C.double()

    K = probs_N_K_C.shape[1]

    # Given the probabilities in probs_N_K_C, draw S samples of class
    # labels according to the given distributions. In essence, we sample
    # S possible configurations of class labels y_1:n across all data points.
    choices_N_K_S = batch_multi_choices(probs_N_K_C, S).long()

    # Insert an empty dimension into each tensor so they broadcast against each other
    expanded_choices_N_K_K_S = choices_N_K_S[:, None, :, :]
    expanded_probs_N_K_K_C = probs_N_K_C[:, :, None, :]

    # From the sampled class labels gather the probabilities of those classes
    probs_N_K_K_S = gather_expand(
        expanded_probs_N_K_K_C,
        dim=-1,
        index=expanded_choices_N_K_K_S,
    )
    # Probability of each sampled joint configuration: a product over the N
    # points, computed as exp(sum(log(...))) to avoid underflowing to 0
    probs_K_K_S = torch.exp(
        torch.sum(torch.log(probs_N_K_K_S), dim=0, keepdim=False))
    # Flatten the (K, S) sample grid: K x M where M = K * S
    samples_K_M = probs_K_K_S.reshape((K, -1))

    samples_M_K = samples_K_M.t()
    return samples_M_K
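
The snippets on this page rely on two helpers, batch_multi_choices and gather_expand,
that are defined elsewhere in the project. Below is a minimal sketch of plausible
implementations, inferred purely from how they are called in these snippets; the
project's actual helpers may differ in detail.

import torch

def batch_multi_choices(probs_b_C, M):
    # ...b x C -> ...b x M: flatten the leading dims, draw M class labels
    # per row with replacement, then restore the leading dims.
    probs_B_C = probs_b_C.reshape((-1, probs_b_C.shape[-1]))
    choices_B_M = torch.multinomial(probs_B_C, num_samples=M, replacement=True)
    return choices_B_M.reshape(list(probs_b_C.shape[:-1]) + [M])

def gather_expand(data, dim, index):
    # Broadcast `data` and `index` against each other on every dim except
    # `dim`, then gather along `dim`.
    max_shape = [max(d, i) for d, i in zip(data.shape, index.shape)]
    new_data_shape = list(max_shape)
    new_data_shape[dim] = data.shape[dim]
    new_index_shape = list(max_shape)
    new_index_shape[dim] = index.shape[dim]
    return torch.gather(data.expand(new_data_shape), dim, index.expand(new_index_shape))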
Example #2
File: sampling.py  Project: tmfs10/ical
import torch

def sample_M_K_unified(probs_N_K_C, S=1000):
    probs_N_K_C = probs_N_K_C.double()

    K = probs_N_K_C.shape[1]

    # Sample S*K joint label configurations from the mean predictive
    # distribution (averaged over the K model samples) rather than
    # sampling S configurations per model sample
    choices_N_1_M = batch_multi_choices(torch.mean(probs_N_K_C, dim=1, keepdim=True), S * K).long()
    # Gather each model sample's probability for every sampled label
    probs_N_K_M = gather_expand(probs_N_K_C, dim=-1, index=choices_N_1_M)

    # Direct product over the N points; the exp-sum-log variant below is
    # kept (commented out) as a guard against numerical underflow
    # probs_K_M = torch.exp(torch.sum(torch.log(probs_N_K_M), dim=0, keepdim=False))
    probs_K_M = torch.prod(probs_N_K_M, dim=0, keepdim=False)

    samples_M_K = probs_K_M.t()
    return samples_M_K
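
A quick smoke test of the two samplers above, assuming both functions and the
sketched helpers are in scope (shape suffixes: N data points, K model samples,
C classes; M = S * K drawn configurations):

import torch

probs_N_K_C = torch.softmax(torch.randn(4, 8, 3), dim=-1)

samples_M_K = sample_M_K(probs_N_K_C, S=100)          # (S * K) x K = 800 x 8
unified_M_K = sample_M_K_unified(probs_N_K_C, S=100)  # (S * K) x K = 800 x 8
print(samples_M_K.shape, unified_M_K.shape)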
Example #3
import itertools

import torch

# `torch_utils` is the project module providing the same `gather_expand`
# helper sketched above.
def basic_exact_joint_entropy(logits_N_K_C):
    device = logits_N_K_C.device
    N, K, C = logits_N_K_C.shape
    entropy = torch.zeros(1, dtype=torch.float64, device=device)
    # logits_N_K_C are treated as log-probabilities (e.g. log_softmax output)
    N_K_C = torch.exp(logits_N_K_C.double())
    # Exactly enumerate all C^N joint label configurations
    # (tractable only for small N and C)
    for index in itertools.product(range(C), repeat=N):
        expanded_index = torch.as_tensor(index, device=device).reshape(
            (-1, 1, 1))
        # N x K x 1
        N_K = torch_utils.gather_expand(N_K_C, dim=2, index=expanded_index)
        # N x K
        reshaped_N_K = N_K.reshape(N, K)
        # K
        # Joint probability of this configuration under each of the K samples
        joint_prob = torch.prod(reshaped_N_K, dim=0)
        # Average over the K model samples to get the predictive joint probability
        joint_prob = torch.mean(joint_prob)
        entropy_bit = -joint_prob * torch.log(joint_prob)

        entropy += entropy_bit

    return entropy
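
A hypothetical check on a problem small enough for exact enumeration (C^N = 2^3 = 8
configurations); the log_softmax output matches the exp() conversion inside the function:

import torch

logits_N_K_C = torch.log_softmax(torch.randn(3, 5, 2, dtype=torch.float64), dim=-1)
print(basic_exact_joint_entropy(logits_N_K_C))  # 1-element tensor, entropy in nats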