import itertools

import torch

# Assumed imports: batch_multi_choices and gather_expand are this repository's
# own sampling/gather helpers (the torch_utils module is also used below).
import torch_utils
from torch_utils import batch_multi_choices, gather_expand


def sample_M_K(probs_N_K_C, S=1000):
    probs_N_K_C = probs_N_K_C.double()

    K = probs_N_K_C.shape[1]

    # Given the probabilities in probs_N_K_C, take S samples from the class
    # labels according to the given distributions. In essence, we are sampling
    # a possible configuration of class labels y_1:n for all samples.
    choices_N_K_S = batch_multi_choices(probs_N_K_C, S).long()

    # Insert an empty dimension so the sampled configurations can be evaluated
    # against all K parameter samples via broadcasting.
    expanded_choices_N_K_K_S = choices_N_K_S[:, None, :, :]
    expanded_probs_N_K_K_C = probs_N_K_C[:, :, None, :]

    # From the sampled class labels, gather the probabilities of those classes.
    probs_N_K_K_S = gather_expand(
        expanded_probs_N_K_K_C,
        dim=-1,
        index=expanded_choices_N_K_K_S,
    )

    # Calculate the probability of all observed class labels. exp-sum-log
    # appears to be needed to keep the running product from underflowing to 0.
    probs_K_K_S = torch.exp(
        torch.sum(torch.log(probs_N_K_K_S), dim=0, keepdim=False))

    samples_K_M = probs_K_K_S.reshape((K, -1))
    samples_M_K = samples_K_M.t()

    return samples_M_K
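# Usage sketch (illustrative, not from the original code): sample_M_K expects
# softmax probabilities of shape (N, K, C) for N data points, K Monte Carlo
# parameter samples and C classes, and returns M = K * S joint-probability
# samples evaluated under each of the K parameter samples. The shapes and the
# helper name _demo_sample_M_K below are assumptions for demonstration only.
def _demo_sample_M_K():
    N, K, C = 4, 10, 3
    probs_N_K_C = torch.softmax(torch.randn(N, K, C), dim=-1)
    samples_M_K = sample_M_K(probs_N_K_C, S=100)
    # One row per sampled label configuration, one column per parameter sample.
    assert samples_M_K.shape == (K * 100, K)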
def sample_M_K_unified(probs_N_K_C, S=1000):
    probs_N_K_C = probs_N_K_C.double()

    K = probs_N_K_C.shape[1]

    # Draw S * K label configurations from a single, unified proposal: the
    # class distribution averaged over the K parameter samples.
    choices_N_1_M = batch_multi_choices(
        torch.mean(probs_N_K_C, dim=1, keepdim=True), S * K).long()

    # Evaluate every sampled configuration under each of the K parameter
    # samples.
    probs_N_K_M = gather_expand(probs_N_K_C, dim=-1, index=choices_N_1_M)

    # exp-sum-log appears to be needed to keep the running product from
    # underflowing to 0:
    # probs_K_M = torch.exp(
    #     torch.sum(torch.log(probs_N_K_M), dim=0, keepdim=False))
    probs_K_M = torch.prod(probs_N_K_M, dim=0, keepdim=False)

    samples_M_K = probs_K_M.t()
    return samples_M_K
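# Comparison sketch (illustrative, not from the original code): both samplers
# return M x K tensors of joint probabilities. sample_M_K draws S
# configurations per parameter sample, whereas sample_M_K_unified draws S * K
# configurations from the averaged ("unified") class distribution. The helper
# name _demo_unified_vs_per_sample is an assumption.
def _demo_unified_vs_per_sample():
    N, K, C = 4, 10, 3
    probs_N_K_C = torch.softmax(torch.randn(N, K, C), dim=-1)
    per_sample_M_K = sample_M_K(probs_N_K_C, S=100)
    unified_M_K = sample_M_K_unified(probs_N_K_C, S=100)
    # Both yield M = K * S rows, one joint probability per drawn configuration.
    assert per_sample_M_K.shape == unified_M_K.shape == (K * 100, K)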
def basic_exact_joint_entropy(logits_N_K_C):
    device = logits_N_K_C.device
    N, K, C = logits_N_K_C.shape

    entropy = torch.zeros(1, dtype=torch.float64, device=device)

    # logits_N_K_C are treated as log-probabilities, so exponentiating
    # recovers the class probabilities.
    N_K_C = torch.exp(logits_N_K_C.double())

    # Enumerate every possible joint label configuration y_1:N (C ** N of
    # them), which is only tractable for very small N.
    for index in itertools.product(range(C), repeat=N):
        expanded_index = torch.as_tensor(index, device=device).reshape(
            (-1, 1, 1))

        # N x K x 1
        N_K = torch_utils.gather_expand(N_K_C, dim=2, index=expanded_index)
        # N x K
        reshaped_N_K = N_K.reshape(N, K)
        # K
        joint_prob = torch.prod(reshaped_N_K, dim=0)

        # Average over the K parameter samples to obtain p(y_1:N).
        joint_prob = torch.mean(joint_prob)

        entropy_bit = -joint_prob * torch.log(joint_prob)
        entropy += entropy_bit

    return entropy
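# Usage sketch (illustrative, not from the original code): the exact joint
# entropy enumerates all C ** N label configurations, so it is only useful as
# a ground-truth check for very small N. log_softmax is used because the
# function exponentiates its input. The helper name _demo_exact_joint_entropy
# is an assumption.
def _demo_exact_joint_entropy():
    N, K, C = 3, 10, 4
    logits_N_K_C = torch.log_softmax(torch.randn(N, K, C), dim=-1)
    exact_entropy = basic_exact_joint_entropy(logits_N_K_C)
    # The joint entropy of N variables over C classes is at most N * log(C).
    max_entropy = N * torch.log(torch.tensor(C, dtype=torch.float64))
    assert 0.0 <= exact_entropy.item() <= max_entropy.item() + 1e-6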