def sample_partitions(B, N, K, alpha=1.0, rand_K=True, device='cpu'):
    """Draw a batch of random cluster-label sequences.

    Every batch element gets its own mixing proportions from a symmetric
    Dirichlet(alpha) over K components, and N labels are drawn from them.

    Args:
        B: number of independent label sequences (batch size).
        N: number of labels drawn per batch element.
        K: number of components.
        alpha: concentration of the symmetric Dirichlet prior.
        rand_K: when True, every component except component 0 is kept
            independently with probability 0.5 and the proportions are
            renormalised over the kept components.
        device: device the proportions are moved to before labels are drawn.

    Returns:
        Contiguous integer tensor of shape (B, N) with values in [0, K).
    """
    weights = Dirichlet(alpha * torch.ones(K)).sample([B]).to(device)

    if rand_K:
        # Coin-flip mask per component; component 0 is forced on so that
        # no row loses every component.
        keep = torch.rand(B, K) < 0.5
        keep[:, 0] = True
        weights = weights * keep.float().to(device)
        weights = weights / weights.sum(dim=1, keepdim=True)

    # Categorical is batched over the B rows, so sample((N,)) gives (N, B);
    # transpose to (B, N).
    draws = Categorical(probs=weights).sample((N,)).to(device)
    return draws.t().contiguous()
def sample_labels(B, N, K_low, K_high, alpha=1.0, device='cpu'):
    """Draw a batch of label sequences with a random number of active classes.

    For each of the B batch elements, mixing proportions over K_high classes
    are drawn from a symmetric Dirichlet(alpha). A class count k is drawn
    uniformly from [K_low, K_high] (inclusive), and only the first k classes
    are kept, with proportions renormalised over them. N labels are then drawn
    per batch element.

    Args:
        B: batch size.
        N: number of labels drawn per batch element.
        K_low: smallest number of active classes (inclusive); must be >= 1.
        K_high: largest number of active classes (inclusive); must be
            >= K_low (torch.randint rejects an empty range).
        alpha: concentration of the symmetric Dirichlet prior.
        device: device of the returned tensor. Labels are sampled first and
            moved to this device at the end.

    Returns:
        Contiguous integer tensor of shape (B, N). Row i only contains labels
        in [0, k_i), where k_i is the class count drawn for that row.

    Raises:
        ValueError: if K_low < 1. A row with zero active classes would have
            all-zero proportions, and normalising it would yield NaN.
    """
    # Validate before sampling so the RNG stream is untouched on failure.
    if K_low < 1:
        raise ValueError(f"K_low must be >= 1, got {K_low}")

    pi = Dirichlet(alpha * torch.ones(K_high)).sample([B])
    K = torch.randint(K_low, K_high + 1, size=(B, ))

    # Row i keeps classes 0..K[i]-1. A single (B, K_high) broadcast
    # comparison replaces a Python loop over the batch.
    to_use = torch.arange(K_high).unsqueeze(0) < K.unsqueeze(1)
    pi = pi * to_use.to(pi.dtype)
    pi = pi / pi.sum(1, keepdim=True)

    # Categorical is batched over the B rows, so sample([N]) gives (N, B).
    labels = Categorical(probs=pi).sample([N])
    labels = labels.transpose(0, 1).contiguous()
    return labels.to(device)