Example #1
def compute_acquisition_bag(
    bayesian_model: nn.Module,
    acquisition_function: AcquisitionFunction,
    available_loader,
    num_classes: int,
    k: int,
    b: int,
    initial_percentage: int,
    reduce_percentage: int,
    device=None,
) -> AcquisitionBatch:
    if acquisition_function != AcquisitionFunction.random:
        result = reduced_eval_consistent_bayesian_model(
            bayesian_model=bayesian_model,
            acquisition_function=acquisition_function,
            num_classes=num_classes,
            k=k,
            initial_percentage=initial_percentage,
            reduce_percentage=reduce_percentage,
            target_size=b,
            available_loader=available_loader,
            device=device,
        )

        scores_B = result.scores_B
        subset_split = result.subset_split
        result = None  # free the full result before the top-k selection

        top_k_scores, top_k_indices = scores_B.topk(b, largest=True, sorted=True)

        top_k_scores = top_k_scores.numpy()

        # Map our indices to the available_loader dataset.
        top_k_indices = subset_split.get_dataset_indices(top_k_indices.numpy())

        print(f"Acquisition bag: {top_k_indices}")
        print(f"Scores: {top_k_scores}")

        return AcquisitionBatch(top_k_indices, top_k_scores, None)
    else:
        picked_indices = torch.randperm(len(available_loader.dataset))[:b].numpy()

        print(f"Acquisition bag: {picked_indices}")

        return AcquisitionBatch(picked_indices, [0.0] * b, None)
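
To make the index mapping concrete, here is a minimal standalone sketch of the non-random branch: pick the b top-scoring points, then translate subset-local indices into dataset indices. The tensors are toy stand-ins, not the repository's SubsetSplit machinery.

import torch

scores_B = torch.tensor([0.3, 1.2, 0.7, 0.1, 0.9])    # per-point acquisition scores
subset_to_dataset = torch.tensor([10, 42, 7, 3, 19])  # subset position -> dataset index
b = 2

top_k_scores, top_k_indices = scores_B.topk(b, largest=True, sorted=True)
dataset_indices = subset_to_dataset[top_k_indices]
print(dataset_indices.tolist(), top_k_scores.tolist())  # [42, 19] [~1.2, ~0.9]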
Example #2
def compute_acs_fw_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    max_entropy_bag_size,
    device=None,
) -> AcquisitionBatch:
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )

    start_time = time.process_time()

    B, K, C = result.logits_B_K_C.shape  # (pool size, MC dropout samples, classes)
    # exp_() is in-place, so result.logits_B_K_C holds probabilities from here on.
    py = result.logits_B_K_C.exp_().mean(dim=1)

    num_projections = 10
    gamma = 0.7
    assert K >= num_projections
    cs = ProjectedFrankWolfe(py,
                             result.logits_B_K_C[:, :num_projections, :],
                             num_projections,
                             gamma=gamma)

    # The coreset construction may return fewer than b points; top up with
    # unique random indices so the acquisition batch always has size b.
    global_acquisition_bag = cs.build(M=b).tolist()
    s = set(global_acquisition_bag)
    perm = torch.randperm(B)
    bi = 0
    while len(global_acquisition_bag) < b:
        idx = perm[bi].item()
        if idx not in s:
            global_acquisition_bag.append(idx)
            s.add(idx)
        bi += 1

    end_time = time.process_time()  # timing covers the full batch selection
    time_taken = end_time - start_time
    print('ack time taken', time_taken)
    acquisition_bag_scores = []

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
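
The while loop at the end is a generic top-up pattern: pad a possibly short selection with unique random indices until it reaches size b. A self-contained sketch (the helper name is illustrative, not from the repository):

import torch

def top_up_with_random(selection, pool_size, b):
    # Pad `selection` to length b with unique random indices from range(pool_size).
    chosen = set(selection)
    for idx in torch.randperm(pool_size).tolist():
        if len(selection) >= b:
            break
        if idx not in chosen:
            selection.append(idx)
            chosen.add(idx)
    return selection

print(top_up_with_random([3, 7], pool_size=10, b=5))  # e.g. [3, 7, 6, 0, 9]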
Example #3
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )
    start_time = time.process_time()

    subset_split = result.subset_split

    partial_multi_bald_B = result.scores_B
    # Now we can compute the conditional entropy
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(
        result.logits_B_K_C)

    # We turn the logits into probabilities.
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None

    torch_utils.gc_cuda()
    # torch_utils.cuda_meminfo()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k
        num_samples = num_samples_per_ws * k

        if device.type == "cuda":
            # KC_memory = k*num_classes*8
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2**30
            multi_bald_batch_size = (torch_utils.get_cuda_available_memory() -
                                     (sample_MK_memory + copy_buffer_memory +
                                      slack_memory)) // MC_memory

            global compute_multi_bald_bag_multi_bald_batch_size
            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(
                    f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}"
                )
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []
        global_acquisition_bag = []
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None
        for i in range(b):
            torch_utils.gc_cuda()

            if i > 0:
                # Compute the joint entropy
                joint_entropies_B = torch.empty((len(probs_B_K_C), ),
                                                dtype=torch.float64)

                exact_samples = num_classes**i
                if exact_samples <= num_samples:
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(
                            device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    # torch_utils.cuda_meminfo()
                    batch_exact_joint_entropy(probs_B_K_C,
                                              prev_joint_probs_M_K,
                                              multi_bald_batch_size, device,
                                              joint_entropies_B)
                else:
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device),
                        S=num_samples_per_ws)

                    # torch_utils.cuda_meminfo()
                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                            torch_utils.split_tensors(joint_entropies_B,
                                                      probs_B_K_C,
                                                      multi_bald_batch_size),
                            unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(joint_entropy_sampling.batch(
                            probs_b_K_C.to(device), prev_samples_M_K),
                                                non_blocking=True)

                        # torch_utils.cuda_meminfo()

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD score of the winner:
            actual_multi_bald = (
                partial_multi_bald_B[winner_index] -
                torch.sum(conditional_entropies_B[subset_acquisition_bag])).item()

            print(f"Actual MultiBALD: {actual_multi_bald}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out after acquiring at least 1 sample. The spread
            # compares the winner's score against the median score over the pool.
            if early_out and i > 1:
                current_spread = (partial_multi_bald_B[winner_index] -
                                  partial_multi_bald_B.median()).item()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald)

            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(
                subset_split.get_dataset_indices([winner_index]).item())

            print(
                f"Acquisition bag: {sorted(global_acquisition_bag)}, num_ack: {i}"
            )

    end_time = time.process_time()
    time_taken = end_time - start_time
    print('ack time taken', time_taken)

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
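
The quantity driving this loop is the joint entropy of the bag's predictions, averaged over MC dropout samples. Below is a hedged sketch of the exact computation for a small bag; it mirrors what joint_entropy_exact does conceptually, but the helper name and the explicit enumeration over all C**n label configurations are mine.

import torch

def exact_joint_entropy(probs_n_K_C):
    # probs_n_K_C: (bag size n, K MC samples, C classes).
    # p(y_1..y_n) is the mean over model samples w_k of the product of
    # per-point class probabilities, over all C**n configurations.
    n, K, C = probs_n_K_C.shape
    joint_K = torch.ones(K, 1, dtype=torch.float64)
    for i in range(n):
        joint_K = (joint_K[:, :, None] *
                   probs_n_K_C[i].double()[:, None, :]).reshape(K, -1)
    p_joint = joint_K.mean(dim=0)  # (C**n,)
    return -(p_joint * p_joint.clamp_min(1e-12).log()).sum()

probs = torch.softmax(torch.randn(3, 8, 4, dtype=torch.float64), dim=-1)  # n=3, K=8, C=4
print(exact_joint_entropy(probs))  # H(y_1, y_2, y_3)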
Example #4
def compute_ical_pointwise(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    max_batch_compute_size,
    hsic_compute_batch_size,
    hsic_kernel_name,
    max_greedy_iterations=0,
    hsic_resample=True,
    device=None,
    store=None,
    num_to_condense=200,
    num_inference_for_marginal_stat=0,
    use_orig_condense=False,
) -> AcquisitionBatch:
    assert hsic_compute_batch_size is not None
    assert hsic_kernel_name is not None

    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )

    start_time = time.process_time()

    # exp_() is in-place, so result.logits_B_K_C holds probabilities from here on.
    probs_B_K_C = result.logits_B_K_C.exp_()
    B, K, C = probs_B_K_C.shape

    # Draw one class per (pool point, MC sample); Categorical's first positional
    # argument is probs, which matches the in-place exp_ above.
    dist_B_K_C = tdist.categorical.Categorical(probs_B_K_C.view(-1, C))  # B*K x C
    sample_B_K_C = dist_B_K_C.sample([1])  # 1 x B*K
    assert list(sample_B_K_C.shape) == [1, B * K]
    sample_B_K_C = sample_B_K_C[0]
    oh_sample = torch.eye(C)[sample_B_K_C]  # one-hot, B*K x C
    sample_B_K_C = oh_sample.view(B, K, C)

    kernel_fn = getattr(hsic, hsic_kernel_name + '_kernels')

    dist_matrices = []
    bs = 0
    while bs < B:
        be = min(B, bs + hsic_compute_batch_size)
        # Pairwise squared distances between the K MC predictions (n=K, d=B, k=C).
        dist_matrix = hsic.sqdist(sample_B_K_C[bs:be].permute([1, 0, 2]))
        dist_matrices += [dist_matrix]
        bs = be
    dist_matrices = torch.cat(dist_matrices, dim=-1)

    bs = 0
    while bs < B:
        be = min(B, bs + hsic_compute_batch_size)
        dist_matrices[:, :, bs:be] = kernel_fn(dist_matrices[:, :, bs:be])
        bs = be
    kernel_matrices = dist_matrices.permute([2, 0, 1]).to(device)  # B, K, K
    assert list(kernel_matrices.shape) == [B, K, K], \
        "%s == %s" % (kernel_matrices.shape, [B, K, K])

    ack_bag = []
    global_acquisition_bag = []
    acquisition_bag_scores = []
    batch_kernel = None
    print('Computing HSIC for', B, 'points')

    score_sort, score_sort_idx = torch.sort(result.scores_B, descending=True)
    indices_to_condense = np.random.randint(low=0,
                                            high=score_sort_idx.shape[0],
                                            size=num_to_condense)

    if max_greedy_iterations == 0:
        max_greedy_iterations = b
    assert b % max_greedy_iterations == 0, "acquisition batch size must be a multiple of (ical_)max_greedy_iterations!"
    greedy_ack_batch_size = b // max_greedy_iterations
    print('max_greedy_iterations', max_greedy_iterations,
          'greedy_ack_batch_size', greedy_ack_batch_size)

    div_condense_num = 10

    for ackb_i in range(max_greedy_iterations):
        bs = 0
        hsic_scores = []
        condense_kernels = kernel_matrices[indices_to_condense].permute(
            [1, 2, 0]).mean(dim=-1, keepdim=True).unsqueeze(0)  # 1, K, K, 1
        div_indices_to_condense = np.random.randint(
            low=0, high=score_sort_idx.shape[0], size=max_batch_compute_size)
        if use_orig_condense:
            div_indices_to_condense = indices_to_condense
        div_size = div_indices_to_condense.shape[0]
        div_condense_kernels = kernel_matrices[div_indices_to_condense].unsqueeze(1)  # div_size, 1, K, K
        while bs < B:
            be = min(B, bs + hsic_compute_batch_size)
            m = be - bs

            if batch_kernel is None:
                hsic_scores += [
                    hsic.total_hsic_parallel(
                        torch.cat([
                            condense_kernels.repeat([m, 1, 1, 1]),
                            kernel_matrices[bs:be].unsqueeze(-1),
                        ], dim=-1).to(device))
                ]
            else:
                num_ack = len(ack_bag)
                if num_inference_for_marginal_stat > 0:
                    marginal_stat_K_idx = torch.randperm(K)[:num_inference_for_marginal_stat]
                else:
                    marginal_stat_K_idx = torch.arange(K)
                K2 = marginal_stat_K_idx.shape[0]

                if K2 < K:
                    # Restrict every kernel to the K2 sub-sampled MC indices.
                    cur_og_batch_kernel = batch_kernel[marginal_stat_K_idx][
                        :, marginal_stat_K_idx][None, :, :, None].repeat(
                            [m, 1, 1, 1])  # M, K2, K2, 1
                    cur_batch_kernel = (
                        cur_og_batch_kernel * num_ack +
                        kernel_matrices[bs:be][:, marginal_stat_K_idx][
                            :, :, marginal_stat_K_idx].unsqueeze(-1)
                    ) / (num_ack + 1)  # M, K2, K2, 1
                    cur_div_condense_kernels = div_condense_kernels[
                        :, :, marginal_stat_K_idx][
                        :, :, :, marginal_stat_K_idx].repeat(
                            [1, m, 1, 1]).view(-1, K2, K2, 1)  # div_size*M, K2, K2, 1
                else:
                    cur_og_batch_kernel = batch_kernel[None, :, :, None].repeat(
                        [m, 1, 1, 1])  # M, K, K, 1
                    cur_batch_kernel = (
                        cur_og_batch_kernel * num_ack +
                        kernel_matrices[bs:be].unsqueeze(-1)
                    ) / (num_ack + 1)  # M, K, K, 1
                    cur_div_condense_kernels = div_condense_kernels.repeat(
                        [1, m, 1, 1]).view(-1, K2, K2, 1)  # div_size*M, K2, K2, 1
                assert list(cur_batch_kernel.shape) == [m, K2, K2, 1], cur_batch_kernel.shape
                assert list(cur_div_condense_kernels.shape) == [
                    m * div_size, K2, K2, 1], cur_div_condense_kernels.shape
                hsic_scores1 = hsic.total_hsic_parallel(
                    torch.cat([
                        cur_div_condense_kernels,  # div_size*M, K2, K2, 1
                        cur_batch_kernel.unsqueeze(0).repeat(
                            [div_size, 1, 1, 1, 1]).view(-1, K2, K2, 1),
                    ], dim=-1).to(device))  # concatenated: div_size*M, K2, K2, 2
                hsic_scores2 = hsic.total_hsic_parallel(
                    torch.cat([
                        cur_div_condense_kernels,
                        cur_og_batch_kernel.unsqueeze(0).repeat(
                            [div_size, 1, 1, 1, 1]).view(-1, K2, K2, 1),
                    ], dim=-1).to(device))

                if not use_orig_condense:
                    to_add = max(hsic_scores1.min().item(), hsic_scores2.min().item())
                    hsic_scores1 += to_add + 1e-8
                    hsic_scores2 += to_add + 1e-8
                    scores = (hsic_scores1 / hsic_scores2).view(div_size, m)
                    scores = torch.max(scores, torch.tensor(1., device=scores.device))
                    # Marginal fractional improvement in dependency.
                    marginal_improvement_ratio = scores.mean(0)

                if K2 == K:
                    hsic_scores1 = hsic.total_hsic_parallel(
                        torch.cat([
                            condense_kernels.repeat([m, 1, 1, 1]),  # M, K, K, 1
                            cur_batch_kernel,
                        ], dim=-1).to(device))
                else:
                    # Recompute the updated batch kernel over all K MC samples.
                    cur_og_batch_kernel = batch_kernel[None, :, :, None].repeat(
                        [m, 1, 1, 1])  # M, K, K, 1
                    cur_batch_kernel = (
                        cur_og_batch_kernel * num_ack +
                        kernel_matrices[bs:be].unsqueeze(-1)
                    ) / (num_ack + 1)  # M, K, K, 1
                    hsic_scores1 = hsic.total_hsic_parallel(
                        torch.cat([
                            condense_kernels.repeat([m, 1, 1, 1]),  # M, K, K, 1
                            cur_batch_kernel,  # M, K, K, 1
                        ], dim=-1).to(device))
                if use_orig_condense:
                    scores = hsic_scores1 - hsic_scores2.view(div_size, m)
                    scores = torch.max(scores,
                                       torch.tensor(0., device=scores.device))
                    hsic_scores += [scores.mean(0)]
                else:
                    hsic_scores1 *= (marginal_improvement_ratio - 1)
                    hsic_scores += [hsic_scores1]
                torch.cuda.empty_cache()
            bs = be

        hsic_scores = torch.cat(hsic_scores)
        hsic_scores[ack_bag] = -math.inf

        _, sorted_idxes = torch.sort(hsic_scores, descending=True)
        winner_idxes = []
        for g_ack_i in range(greedy_ack_batch_size):
            winner_idxes += [sorted_idxes[g_ack_i].item()]

        old_num_acks = len(ack_bag)
        ack_bag += winner_idxes
        new_num_acks = len(ack_bag)
        global_acquisition_bag += [
            i.item()
            for i in result.subset_split.get_dataset_indices(winner_idxes)
        ]
        acquisition_bag_scores += [s.item() for s in hsic_scores[winner_idxes]]
        print('winner score', result.scores_B[winner_idxes].mean().item(),
              ', hsic_score', hsic_scores[winner_idxes].mean().item(),
              ', ackb_i', ackb_i)
        if batch_kernel is None:
            batch_kernel = kernel_matrices[winner_idxes].mean(0)  # K, K
        else:
            batch_kernel = (
                batch_kernel * old_num_acks +
                kernel_matrices[winner_idxes].sum(0)) / new_num_acks
        assert len(batch_kernel.shape) == 2

        result.scores_B[winner_idxes] = -math.inf
        score_sort, score_sort_idx = torch.sort(result.scores_B, descending=True)
        if hsic_resample:
            indices_to_condense = np.random.randint(
                low=0, high=score_sort_idx.shape[0], size=num_to_condense)

    assert len(ack_bag) == b
    np.set_printoptions(precision=3, suppress=True)

    end_time = time.process_time()
    time_taken = end_time - start_time
    print('ack time taken', time_taken)
    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
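
The chunked kernel construction at the top of the function can be summarized with generic tensor ops. A minimal sketch, assuming an RBF kernel with a median-heuristic bandwidth as a stand-in for hsic.<kernel_name>_kernels and torch.cdist in place of hsic.sqdist:

import torch

def rbf_kernel(sqd):
    # Median-heuristic bandwidth; an assumption, not the repository's kernel.
    return torch.exp(-sqd / sqd.median().clamp_min(1e-8))

def per_point_kernels(sample_B_K_C, chunk):
    # For each pool point, a K x K kernel between its K one-hot MC predictions,
    # computed `chunk` points at a time to bound peak memory.
    out = []
    for bs in range(0, sample_B_K_C.shape[0], chunk):
        x = sample_B_K_C[bs:bs + chunk]  # (m, K, C)
        sqd = torch.cdist(x, x) ** 2     # (m, K, K)
        out.append(rbf_kernel(sqd))
    return torch.cat(out, dim=0)         # (B, K, K)

samples = torch.eye(4)[torch.randint(0, 4, (16, 8))]  # B=16, K=8, C=4, one-hot
print(per_point_kernels(samples, chunk=5).shape)      # torch.Size([16, 8, 8])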
Example #5
def compute_ical(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    max_batch_compute_size,
    hsic_compute_batch_size,
    hsic_kernel_name,
    max_greedy_iterations=0,
    hsic_resample=True,
    device=None,
    store=None,
    num_to_condense=200,
) -> AcquisitionBatch:
    assert hsic_compute_batch_size is not None
    assert hsic_kernel_name is not None

    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )

    start_time = time.process_time()

    # exp_() is in-place, so result.logits_B_K_C holds probabilities from here on.
    probs_B_K_C = result.logits_B_K_C.exp_()
    B, K, C = probs_B_K_C.shape

    # Draw one class per (pool point, MC sample); Categorical's first positional
    # argument is probs, which matches the in-place exp_ above.
    dist_B_K_C = tdist.categorical.Categorical(probs_B_K_C.view(-1, C))  # B*K x C
    sample_B_K_C = dist_B_K_C.sample([1])  # 1 x B*K
    assert list(sample_B_K_C.shape) == [1, B * K]
    sample_B_K_C = sample_B_K_C[0]
    oh_sample = torch.eye(C)[sample_B_K_C]  # one-hot, B*K x C
    sample_B_K_C = oh_sample.view(B, K, C)

    kernel_fn = getattr(hsic, hsic_kernel_name + '_kernels')

    dist_matrices = []
    bs = 0
    while bs < B:
        be = min(B, bs + hsic_compute_batch_size)
        # Pairwise squared distances between the K MC predictions (n=K, d=B, k=C).
        dist_matrix = hsic.sqdist(sample_B_K_C[bs:be].permute([1, 0, 2]))
        dist_matrices += [dist_matrix]
        bs = be
    dist_matrices = torch.cat(dist_matrices, dim=-1)

    bs = 0
    while bs < B:
        be = min(B, bs + hsic_compute_batch_size)
        dist_matrices[:, :, bs:be] = kernel_fn(dist_matrices[:, :, bs:be])
        bs = be
    kernel_matrices = dist_matrices.permute([2, 0, 1]).to(device)  # B, K, K
    assert list(kernel_matrices.shape) == [B, K, K], \
        "%s == %s" % (kernel_matrices.shape, [B, K, K])

    ack_bag = []
    global_acquisition_bag = []
    acquisition_bag_scores = []
    batch_kernel = None
    print('Computing HSIC for', B, 'points')

    score_sort, score_sort_idx = torch.sort(result.scores_B, descending=True)
    indices_to_condense = np.random.randint(low=0,
                                            high=score_sort_idx.shape[0],
                                            size=num_to_condense)

    if max_greedy_iterations == 0:
        max_greedy_iterations = b
    assert b % max_greedy_iterations == 0, "acquisition batch size must be a multiple of (ical_)max_greedy_iterations!"
    greedy_ack_batch_size = b // max_greedy_iterations
    print('max_greedy_iterations', max_greedy_iterations,
          'greedy_ack_batch_size', greedy_ack_batch_size)

    for ackb_i in range(max_greedy_iterations):
        bs = 0
        hsic_scores = []
        condense_kernels = kernel_matrices[indices_to_condense].permute(
            [1, 2, 0]).mean(dim=-1, keepdim=True).unsqueeze(0)  # 1, K, K, 1
        while bs < B:
            be = min(B, bs + hsic_compute_batch_size)
            m = be - bs

            if batch_kernel is None:
                hsic_scores += [
                    hsic.total_hsic_parallel(
                        torch.cat([
                            condense_kernels.repeat([m, 1, 1, 1]),
                            kernel_matrices[bs:be].unsqueeze(-1),
                        ], dim=-1).to(device))
                ]
            else:
                # Average the stored winner kernels with each candidate's kernel,
                # then score dependency against the condense set.
                hsic_scores += [
                    hsic.total_hsic_parallel(
                        torch.cat([
                            condense_kernels.repeat([m, 1, 1, 1]),  # M, K, K, 1
                            torch.cat([
                                batch_kernel.unsqueeze(0).repeat([m, 1, 1, 1]),  # M, K, K, L
                                kernel_matrices[bs:be].unsqueeze(-1),  # M, K, K, 1
                            ], dim=-1).mean(dim=-1, keepdim=True),
                        ], dim=-1).to(device))
                ]
            bs = be

        hsic_scores = torch.cat(hsic_scores)
        hsic_scores[ack_bag] = -math.inf

        _, sorted_idxes = torch.sort(hsic_scores, descending=True)
        winner_idxes = []
        g_ack_i = 0
        while len(winner_idxes) < greedy_ack_batch_size:
            assert g_ack_i < sorted_idxes.shape[0]
            idx = sorted_idxes[g_ack_i].item()
            g_ack_i += 1
            if idx in ack_bag:
                continue
            winner_idxes += [idx]

        ack_bag += winner_idxes
        global_acquisition_bag += [
            i.item()
            for i in result.subset_split.get_dataset_indices(winner_idxes)
        ]
        acquisition_bag_scores += [s.item() for s in hsic_scores[winner_idxes]]
        print('winner score', result.scores_B[winner_idxes].mean().item(),
              ', hsic_score', hsic_scores[winner_idxes].mean().item(),
              ', ackb_i', ackb_i)
        if batch_kernel is None:
            batch_kernel = kernel_matrices[winner_idxes].permute([1, 2, 0])  # K, K, L
        else:
            batch_kernel = torch.cat([
                batch_kernel,
                kernel_matrices[winner_idxes].permute([1, 2, 0]),
            ], dim=-1)  # K, K, ack_size
            assert len(batch_kernel.shape) == 3
        # Sub-sample the stored kernels to cap memory in later iterations.
        if max_batch_compute_size != 0 and batch_kernel.shape[-1] >= max_batch_compute_size:
            idxes = np.random.choice(batch_kernel.shape[-1],
                                     size=max_batch_compute_size,
                                     replace=False)
            batch_kernel = batch_kernel[:, :, idxes]

        result.scores_B[winner_idxes] = -math.inf
        score_sort, score_sort_idx = torch.sort(result.scores_B, descending=True)
        if hsic_resample:
            indices_to_condense = np.random.randint(
                low=0, high=score_sort_idx.shape[0], size=num_to_condense)

    assert len(ack_bag) == b
    np.set_printoptions(precision=3, suppress=True)

    end_time = time.process_time()
    time_taken = end_time - start_time
    print('ack time taken', time_taken)
    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
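
The main difference between compute_ical and compute_ical_pointwise is this batch-kernel bookkeeping: the former keeps a growing, sub-sampled stack of winner kernels, while the pointwise variant maintains a running mean. A small sketch of the running-mean update (names are illustrative):

import torch

def update_mean_kernel(batch_kernel, winner_kernels_W_K_K, num_acks):
    # batch_kernel: (K, K) mean kernel over num_acks acquired points, or None.
    if batch_kernel is None:
        return winner_kernels_W_K_K.mean(dim=0)
    total = num_acks + winner_kernels_W_K_K.shape[0]
    return (batch_kernel * num_acks + winner_kernels_W_K_K.sum(dim=0)) / total

K = 8
mean_kernel, acks = None, 0
for _ in range(3):
    winners = torch.rand(2, K, K)  # kernels of two newly acquired points
    mean_kernel = update_mean_kernel(mean_kernel, winners, acks)
    acks += winners.shape[0]
print(mean_kernel.shape)  # torch.Size([8, 8])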
Example #6
def compute_fass_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    max_entropy_bag_size,
    fass_compute_batch_size,
    device=None,
) -> AcquisitionBatch:
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )

    start_time = time.process_time()

    B, K, C = result.logits_B_K_C.shape
    probs_B_C = result.logits_B_K_C.exp_().mean(dim=1)  # predictive class probabilities
    preds_B = probs_B_C.max(dim=-1)[1]  # hard class predictions
    entropy = -(probs_B_C * probs_B_C.log()).sum(dim=-1)  # predictive entropy

    ack_bag = []
    global_acquisition_bag = []
    acquisition_bag_scores = []

    score_sort, score_sort_idx = torch.sort(entropy, descending=True)

    cand_pts_idx = set(
        score_sort_idx[:max_entropy_bag_size].cpu().numpy().tolist())

    cand_X = []
    cand_X_preds = []
    cand_X_idx = []
    for i, (batch, labels) in enumerate(
            with_progress_bar(available_loader,
                              unit_scale=available_loader.batch_size)):
        lower = i * available_loader.batch_size
        upper = min(lower + available_loader.batch_size, B)
        # Subset-global indices of candidate points that fall into this batch.
        idx_to_extract = np.array(sorted(
            set(range(lower, upper)).intersection(cand_pts_idx)),
                                  dtype=np.int64)
        cand_X_preds += [preds_B[idx_to_extract]]
        cand_X_idx += [torch.from_numpy(idx_to_extract).long()]
        idx_to_extract -= lower  # shift to batch-local positions

        batch = batch.view(batch.shape[0], -1)  # batch_size x num_features
        cand_X += [batch[idx_to_extract]]

    cand_X = torch.cat(cand_X, dim=0).unsqueeze(1).to(device)
    cand_X_preds = torch.cat(cand_X_preds, dim=0).to(device)
    cand_X_idx = torch.cat(cand_X_idx, dim=0).to(device)

    num_cands = cand_X.shape[0]
    if num_cands > fass_compute_batch_size and fass_compute_batch_size > 0:
        # Chunk the pairwise distance computation to bound memory.
        sqdist = []
        for bs in range(0, num_cands, fass_compute_batch_size):
            be = min(num_cands, bs + fass_compute_batch_size)
            sqdist += [hsic.sqdist(cand_X[bs:be], cand_X).mean(-1).cpu()]
        sqdist = torch.cat(sqdist, dim=0).to(device)
    else:
        sqdist = hsic.sqdist(cand_X, cand_X).mean(-1)  # num_cands x num_cands
    max_dist = sqdist.max()

    cand_min_dist = torch.ones((cand_X.shape[0], ), device=device) * max_dist
    for ackb_i in range(b):
        cand_distance = torch.ones((cand_X.shape[0], ), device=device) * max_dist
        for c in range(C):
            cand_c_idx = cand_X_preds == c
            if cand_c_idx.long().sum() == 0:
                continue
            # Chunk over columns to bound memory when taking the elementwise min.
            temp2 = []
            for bs in range(0, sqdist.shape[1], 5000):
                be = min(sqdist.shape[1], bs + 5000)
                bl = be - bs
                temp = torch.cat([
                    cand_min_dist[cand_c_idx].unsqueeze(-1).repeat([1, bl]).unsqueeze(-1),
                    sqdist[cand_c_idx, bs:be].unsqueeze(-1),
                ], dim=-1)
                temp2 += [torch.min(temp, dim=-1)[0].detach()]
                del temp
                torch.cuda.empty_cache()
            temp2 = torch.cat(temp2, dim=1).mean(1).detach()
            cand_distance[cand_c_idx] = temp2
        cand_distance[ack_bag] = max_dist
        winner_index = cand_distance.argmin().item()
        ack_bag += [winner_index]
        winner_index = cand_X_idx[winner_index].item()
        global_acquisition_bag.append(
            result.subset_split.get_dataset_indices([winner_index]).item())

    assert len(ack_bag) == b
    np.set_printoptions(precision=3, suppress=True)

    end_time = time.process_time()
    time_taken = end_time - start_time
    print('ack time taken', time_taken)
    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
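
The inner chunked minimum can be expressed more directly with torch.minimum. A hedged sketch of the coverage-distance update under the same column chunking (the helper name is mine, not from the repository):

import torch

def chunked_mean_min_dist(cand_min_dist, sqdist_rows, chunk=5000):
    # For each candidate row: min(current coverage distance, distance to each
    # pool column), averaged over columns; chunked to bound memory.
    parts = []
    for bs in range(0, sqdist_rows.shape[1], chunk):
        parts.append(torch.minimum(cand_min_dist[:, None],
                                   sqdist_rows[:, bs:bs + chunk]))
    return torch.cat(parts, dim=1).mean(dim=1)

rows = torch.rand(4, 12000)  # distances from 4 candidates to 12000 pool points
init = torch.full((4,), rows.max().item())
print(chunked_mean_min_dist(init, rows).shape)  # torch.Size([4])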
Example #7
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,  # Number of Monte Carlo dropout samples
    b,  # Acquisition batch size (how many samples to label next)
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:

    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,  # i.e. mutual information
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )
    # result holds scores for a reduced pool of the most promising
    # (highest mutual information) candidates.
    subset_split = result.subset_split

    partial_multi_bald_B = result.scores_B
    # partial_multi_bald_B contains H(y_1, ..., y_n, y_m) - E_p(w)[H(y_m|w)],
    # where 1..n are the samples already in the acquisition bag and m ranges
    # over the available candidates. For the first sample to be selected this
    # reduces to H(y_m) - E_p(w)[H(y_m|w)], i.e. the mutual information between
    # y_m and the model parameters w. Since the term E_p(w)[H(y_1, ..., y_n|w)]
    # that has to be subtracted to obtain the true a_BatchBALD value is the
    # same for every candidate, we can ignore it when ranking them.

    # Now we can compute the conditional entropies.
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(
        result.logits_B_K_C)
    # conditional_entropies_B holds E_p(w)[H(y_i|w)] per point. Summed over the
    # acquisition bag this gives E_p(w)[H(y_1, ..., y_n|w)], the right-hand
    # side of Equation 8 used to calculate BatchBALD.

    # We turn the logits into probabilities.
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None

    torch_utils.gc_cuda()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k  # Number of samples used to calculate joint entropy for each sample of the model
        num_samples = num_samples_per_ws * k

        # Decide how many samples should be calculated at once when determining the joint entropy
        if device.type == "cuda":
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2**30
            multi_bald_batch_size = (torch_utils.get_cuda_available_memory() -
                                     (sample_MK_memory + copy_buffer_memory +
                                      slack_memory)) // MC_memory

            global compute_multi_bald_bag_multi_bald_batch_size
            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(
                    f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}"
                )
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []  # selected samples (subset-local indices)
        global_acquisition_bag = []  # selected samples (dataset-global indices)
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None
        # Iteratively select b samples for labeling and put them in
        # the acquisition_bag
        for i in range(b):  # Algorithm 1 : Line number 2
            torch_utils.gc_cuda()

            if i > 0:  # Only run this starting from the second sample

                # Compute the joint entropies. Depending on the size of n
                # (y_1, ..., y_n) we can either solve this analytically using
                # joint_entropy.exact or via sampling using joint_entropy.sample.
                # The entropies can be computed iteratively, reusing information
                # obtained when the previous point was added to the bag.

                joint_entropies_B = torch.empty((len(probs_B_K_C), ),
                                                dtype=torch.float64)

                # If we can, calculate joint entropy analytically, otherwise use sampling
                exact_samples = num_classes**i
                if exact_samples <= num_samples:  # Use exact joint entropy (no sampling)
                    # Update P_{1:n-1}: the joint probabilities of the current
                    # bag under each model sample w.
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    batch_exact_joint_entropy(
                        probs_B_K_C,  # class probabilities for the whole pool
                        prev_joint_probs_M_K,  # P_{1:n-1}, joint probs of the current bag
                        multi_bald_batch_size,  # number of pool points per chunk
                        device,  # compute on GPU or CPU
                        joint_entropies_B,  # output: filled with the joint entropies
                    )
                else:  # use sampling to get joint entropy
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device),
                        S=num_samples_per_ws)
                    # prev_samples_M_K holds the probabilities of the sampled
                    # label configurations of the current acquisition bag,
                    # i.e. p(y^_1:n-1 | w^_j), and therefore estimates P^_{1:n-1}.

                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                            torch_utils.split_tensors(joint_entropies_B,
                                                      probs_B_K_C,
                                                      multi_bald_batch_size),
                            unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(joint_entropy_sampling.batch(
                            probs_b_K_C.to(device), prev_samples_M_K),
                                                non_blocking=True)

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            # Algorithm 1 : Line 4

            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD score of the winner:
            actual_multi_bald = (
                partial_multi_bald_B[winner_index] -
                torch.sum(conditional_entropies_B[subset_acquisition_bag])).item()

            print(f"Actual MultiBALD: {actual_multi_bald}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out after acquiring at least 1 sample. The spread
            # compares the winner's score against the median score over the pool.
            if early_out and i > 1:
                current_spread = (partial_multi_bald_B[winner_index] -
                                  partial_multi_bald_B.median()).item()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald)
            # Algorithm 1 : Line 5
            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(
                subset_split.get_dataset_indices([winner_index]).item())

            print(f"Acquisition bag: {sorted(global_acquisition_bag)}")

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None)
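
For reference, the two commented terms of Equation 8 can be estimated directly from a bag's class probabilities. A hedged sketch of the conditional-entropy term (the helper mirrors what batch_conditional_entropy_B computes conceptually; pairing it with the joint-entropy sketch after Example #3 yields the full a_BatchBALD value):

import torch

def conditional_entropy_sum(probs_n_K_C):
    # E_p(w)[sum_i H(y_i | w)] over K MC dropout samples: entropy per
    # (point, model sample), mean over samples, summed over the bag.
    p = probs_n_K_C.double()
    h_n_K = -(p * p.clamp_min(1e-12).log()).sum(dim=-1)  # (n, K)
    return h_n_K.mean(dim=1).sum()

probs = torch.softmax(torch.randn(3, 8, 4, dtype=torch.float64), dim=-1)
print(conditional_entropy_sum(probs))  # right-hand term of Equation 8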