コード例 #1
0
ファイル: multi_bald.py プロジェクト: Rexhaif/BatchBALD
def batch_exact_joint_entropy(probs_B_K_C, prev_joint_probs_M_K, chunk_size,
                              device, out_joint_entropies_B):
    """This one switches between devices, too."""
    for joint_entropies_b, probs_b_K_C in tqdm(torch_utils.split_tensors(
            out_joint_entropies_B, probs_B_K_C, chunk_size),
                                               unit_scale=chunk_size):
        joint_entropies_b.copy_(joint_entropy_exact.batch(
            probs_b_K_C.to(device), prev_joint_probs_M_K),
                                non_blocking=True)

    return joint_entropies_b
コード例 #2
0
ファイル: test_joint_entropy.py プロジェクト: tmfs10/ical
def test_exact_joint_entropy():
    probs_B_K_C = torch.exp(logits_B_K_C)

    basic_result = basic_exact_joint_entropy(logits_B_K_C[0:3]).item()
    joint_probs_result = exact.entropy_from_M_K(
        exact.joint_probs_M_K(probs_B_K_C[0:3])).item()

    prev_joint_probs_M_K = exact.joint_probs_M_K(probs_B_K_C[0:2])
    batch_result = exact.batch(probs_B_K_C[2][None],
                               prev_joint_probs_M_K).item()

    assert np.isclose(basic_result, joint_probs_result)
    assert np.isclose(basic_result, batch_result)
コード例 #3
0
ファイル: test_joint_entropy.py プロジェクト: tmfs10/ical
def test_unified_sampling_joint_entropy():
    K = logits_B_K_C.shape[1]
    probs_B_K_C = torch.exp(logits_B_K_C)

    basic_result = basic_exact_joint_entropy(logits_B_K_C[0:3]).item()
    samples_result = sampling.from_M_K(
        sampling.sample_M_K_unified(probs_B_K_C[0:3], 10000 // K)).item()

    assert np.isclose(basic_result, samples_result, atol=0.05)

    prev_joint_probs_M_K = exact.joint_probs_M_K(probs_B_K_C[0:2])
    exact_batch_results = exact.batch(probs_B_K_C,
                                      prev_joint_probs_M_K).numpy()

    samples_M_K = sampling.sample_M_K(probs_B_K_C[0:2], 10000 // K)
    sampling_batch_results = sampling.batch(probs_B_K_C, samples_M_K).numpy()

    assert np.allclose(exact_batch_results, sampling_batch_results, atol=0.05)