def batch_exact_joint_entropy(probs_B_K_C, prev_joint_probs_M_K, chunk_size, device, out_joint_entropies_B):
    """Fill ``out_joint_entropies_B`` chunk by chunk with exact joint entropies.

    ``out_joint_entropies_B`` and ``probs_B_K_C`` are split together through
    ``torch_utils.split_tensors`` using ``chunk_size``. Each probability chunk
    is moved to ``device``, handed to ``joint_entropy_exact.batch`` together
    with ``prev_joint_probs_M_K``, and the result is copied into the matching
    output chunk. The inputs/outputs may therefore live on a different device
    than the one doing the computation.

    Args:
        probs_B_K_C: Probabilities to evaluate (presumably laid out as
            batch x K x classes, as the suffix suggests -- TODO confirm).
        prev_joint_probs_M_K: Forwarded unchanged to
            ``joint_entropy_exact.batch``. It is not moved here, so it is
            presumably expected to already be on ``device`` -- verify
            against the caller.
        chunk_size: Chunk size passed to ``torch_utils.split_tensors``; also
            used as the tqdm ``unit_scale``.
        device: Device each probability chunk is moved to before computing.
        out_joint_entropies_B: Preallocated output tensor that receives the
            results (assumes the chunks produced by ``split_tensors`` are
            views into it -- TODO confirm).

    Returns:
        ``out_joint_entropies_B``, the complete output tensor.
    """
    chunk_pairs = torch_utils.split_tensors(out_joint_entropies_B, probs_B_K_C, chunk_size)

    for joint_entropies_b, probs_b_K_C in tqdm(chunk_pairs, unit_scale=chunk_size):
        chunk_result = joint_entropy_exact.batch(probs_b_K_C.to(device), prev_joint_probs_M_K)
        # NOTE(review): non_blocking=True may return before the transfer has
        # finished for some source/destination combinations -- confirm callers
        # synchronize before reading the output.
        joint_entropies_b.copy_(chunk_result, non_blocking=True)

    # Return the whole output buffer rather than the loop variable: the loop
    # variable only refers to the last chunk and is unbound when the split
    # yields no chunks at all.
    return out_joint_entropies_B
def test_exact_joint_entropy():
    """Check that the exact joint-entropy code paths agree with the reference.

    Uses the module-level ``logits_B_K_C``. The reference value comes from
    ``basic_exact_joint_entropy`` on the first three entries and is compared
    against (a) the entropy of the joint probabilities of those three entries
    and (b) the ``exact.batch`` result for the third entry given the joint
    probabilities of the first two.
    """
    probabilities = torch.exp(logits_B_K_C)

    # Reference value computed directly from the logits.
    reference = basic_exact_joint_entropy(logits_B_K_C[:3]).item()

    # Path (a): build the joint over all three entries, then take its entropy.
    joint_of_three = exact.joint_probs_M_K(probabilities[:3])
    direct = exact.entropy_from_M_K(joint_of_three).item()

    # Path (b): joint over the first two entries, third entry added through
    # the batch API (it needs a leading dimension of size one).
    joint_of_two = exact.joint_probs_M_K(probabilities[:2])
    third_entry = probabilities[2].unsqueeze(0)
    incremental = exact.batch(third_entry, joint_of_two).item()

    assert np.isclose(reference, direct)
    assert np.isclose(reference, incremental)
def test_unified_sampling_joint_entropy():
    """Check the sampling-based joint entropy against the exact computations.

    Uses the module-level ``logits_B_K_C``. First the unified sampling
    estimate for the first three entries is compared with
    ``basic_exact_joint_entropy``; then ``sampling.batch`` is compared with
    ``exact.batch`` over all entries, both conditioned on the first two.
    Both comparisons use an absolute tolerance of 0.05.
    """
    tolerance = 0.05
    num_k = logits_B_K_C.shape[1]
    samples_per_k = 10000 // num_k
    probabilities = torch.exp(logits_B_K_C)

    # Single joint entropy: reference value vs. unified sampling estimate.
    reference = basic_exact_joint_entropy(logits_B_K_C[:3]).item()
    unified_samples = sampling.sample_M_K_unified(probabilities[:3], samples_per_k)
    sampled_estimate = sampling.from_M_K(unified_samples).item()
    assert np.isclose(reference, sampled_estimate, atol=tolerance)

    # Batched joint entropies: exact vs. sampled, given the first two entries.
    joint_of_two = exact.joint_probs_M_K(probabilities[:2])
    exact_scores = exact.batch(probabilities, joint_of_two).numpy()

    samples_of_two = sampling.sample_M_K(probabilities[:2], samples_per_k)
    sampled_scores = sampling.batch(probabilities, samples_of_two).numpy()

    assert np.allclose(exact_scores, sampled_scores, atol=tolerance)