def compute_scores(self, logits_B_K_C, available_loader, device):
        """Score every point in ``logits_B_K_C`` with this acquisition function.

        Args:
            logits_B_K_C: per-point logits with shape (B, K, C) (unpacked below).
            available_loader: unused by this method; kept for interface compatibility.
            device: torch device the scorer runs on; chunks of ``logits_B_K_C`` are
                moved there one at a time.

        Returns:
            A float64 tensor of shape (B,) with one score per point. For
            ``AcquisitionFunction.random`` the scorer is called once on the whole tensor.
        """
        scorer = self.scorer

        if self == AcquisitionFunction.random:
            return scorer(logits_B_K_C, None).double()

        B, K, C = logits_B_K_C.shape

        # Score in chunks so only one chunk of logits has to live on `device` at a time.
        with torch.no_grad():
            scores_B = torch.empty((B,), dtype=torch.float64)

            if device.type == "cuda":
                torch_utils.gc_cuda()
                # Bytes taken by one point's K x C float64 logits.
                KC_memory = K * C * 8
                batch_size = min(torch_utils.get_cuda_available_memory() // KC_memory, 8192)
            else:
                batch_size = 4096
            # With little free GPU memory the estimate above can reach 0, which is not a
            # usable chunk size; always make progress with at least one point per chunk.
            batch_size = max(int(batch_size), 1)

            for scores_b, logits_b_K_C in with_progress_bar(
                torch_utils.split_tensors(scores_B, logits_B_K_C, batch_size), unit_scale=batch_size
            ):
                scores_b.copy_(scorer(logits_b_K_C.to(device)), non_blocking=True)

        return scores_B
Example #2
0
def batch_exact_joint_entropy(probs_B_K_C, prev_joint_probs_M_K, chunk_size,
                              device, out_joint_entropies_B):
    """Fill ``out_joint_entropies_B`` with exact joint entropies, chunk by chunk.

    Each chunk of ``probs_B_K_C`` (``chunk_size`` rows) is moved to ``device``,
    scored with ``joint_entropy_exact.batch`` against ``prev_joint_probs_M_K`` and
    copied back into the matching slice of ``out_joint_entropies_B``. This means
    the work switches between devices.

    Returns:
        ``out_joint_entropies_B``, filled in place.
    """
    for joint_entropies_b, probs_b_K_C in with_progress_bar(
            torch_utils.split_tensors(out_joint_entropies_B, probs_B_K_C,
                                      chunk_size),
            unit_scale=chunk_size):
        joint_entropies_b.copy_(
            joint_entropy_exact.batch(probs_b_K_C.to(device), prev_joint_probs_M_K),
            non_blocking=True)

    # Return the whole output buffer, not just the view of the last chunk
    # (which would also be undefined for an empty input).
    return out_joint_entropies_B
Example #3
0
def eval_bayesian_model_consistent(
    bayesian_model: mc_dropout.BayesianModule, available_loader, num_classes, k=20, device=None
):
    """Run ``k`` inference samples of ``bayesian_model`` over every point of the loader.

    Samples are evaluated ``chunk_size`` at a time. If a chunk raises a RuntimeError for
    which ``torch_utils.should_reduce_batch_size`` is true, the chunk size is halved and
    stored in the module-level ``eval_bayesian_model_consistent_cuda_chunk_size``, and the
    chunk is retried. Any other error is re-raised, as is the error at chunk size 1.

    Args:
        bayesian_model: called as ``bayesian_model(batch, num_samples)``; its output is
            written into a (batch, num_samples, num_classes) slice of the result.
        available_loader: loader over the pool. Rows are placed by batch position, so it
            presumably must iterate in a fixed order (no shuffling) — TODO confirm.
        num_classes: size of the last output dimension.
        k: total number of inference samples per point.
        device: device that batches are moved to. ``None`` leaves batches where the loader
            produced them and uses the non-CUDA chunk size.

    Returns:
        float64 tensor of shape (len(available_loader.dataset), k, num_classes).
    """
    global eval_bayesian_model_consistent_cuda_chunk_size

    on_cuda = device is not None and device.type == "cuda"

    with torch.no_grad():
        B = len(available_loader.dataset)
        logits_B_K_C = torch.empty((B, k, num_classes), dtype=torch.float64)

        chunk_size = eval_bayesian_model_consistent_cuda_chunk_size if on_cuda else 64
        torch_utils.gc_cuda()
        k_lower = 0
        while k_lower < k:
            try:
                k_upper = min(k_lower + chunk_size, k)

                # This resets the dropout masks.
                bayesian_model.eval()

                for i, (batch, _) in enumerate(
                    with_progress_bar(available_loader, unit_scale=available_loader.batch_size)
                ):
                    lower = i * available_loader.batch_size
                    upper = min(lower + available_loader.batch_size, B)

                    if device is not None:
                        batch = batch.to(device)
                    # batch_size x ws x classes
                    mc_output_B_K_C = bayesian_model(batch, k_upper - k_lower)
                    logits_B_K_C[lower:upper, k_lower:k_upper].copy_(mc_output_B_K_C.double(), non_blocking=True)

            except RuntimeError as exception:
                if torch_utils.should_reduce_batch_size(exception):
                    if chunk_size <= 1:
                        raise
                    chunk_size //= 2
                    print(f"New eval_bayesian_model_consistent_cuda_chunk_size={chunk_size} ({exception})")
                    eval_bayesian_model_consistent_cuda_chunk_size = chunk_size

                    torch_utils.gc_cuda()
                else:
                    raise
            else:
                # The chunk [k_lower, k_upper) is complete.
                k_lower = k_upper

    return logits_B_K_C
Example #4
0
    def run_parser(self):
        """Initiate the check against the zone file.

        Resets ``self.results`` and calls ``self.parse_zonefile()``. If that returns a
        truthy value, calls ``self.check_ip(name=...)`` for every entry in
        ``self.names``. ``check_ip`` is presumably what appends matches to
        ``self.results`` — verify in the class.

        Returns:
            {list} -- list of matches, or None when there are none
        """
        self.results = []
        if self.debug:
            # Lazy %-args: the message is only formatted if the record is emitted.
            logger.info("Getting zonefile info for: %s", self.filename)
        if self.parse_zonefile():
            for name in with_progress_bar(self.names):
                self.check_ip(name=name)
            if self.debug:
                logger.info("Done!")
        if not self.results:
            return None
        logger.debug("Results: %s", self.results)
        return self.results
Example #5
0
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:
    """Greedily acquire a batch with MultiBALD (BatchBALD) and report the selection time.

    The pool is first culled and BALD-scored by ``reduced_eval_consistent_bayesian_model``,
    using ``k`` inference samples. ``initial_percentage``, ``reduce_percentage`` and
    ``target_size`` control the culling.

    Points are then added one at a time. Each step picks the candidate maximising
    ``joint entropy(bag + candidate) - conditional entropy(candidate)``. Joint entropies
    are computed exactly while ``num_classes ** i`` label configurations fit in the
    sample budget, and are estimated by sampling afterwards.

    Args:
        b: number of points to acquire. ``0`` means up to 100 points, with an early out
            once the winner beats the median remaining candidate score by less than
            ``MIN_SPREAD``.
        device: device for the joint-entropy computations. ``device.type`` is read, so it
            must not be None.

    Returns:
        ``(AcquisitionBatch, time_taken)``, where ``time_taken`` is the process time spent
        after the model evaluation.
    """
    global compute_multi_bald_bag_multi_bald_batch_size

    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )
    start_time = time.process_time()

    subset_split = result.subset_split

    # With an empty bag, MultiBALD reduces to BALD, so the BALD scores seed the first pick.
    partial_multi_bald_B = result.scores_B
    # Now we can compute the conditional entropy
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(
        result.logits_B_K_C)

    # We turn the logits into probabilities (in place).
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None

    torch_utils.gc_cuda()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k
        num_samples = num_samples_per_ws * k

        if device.type == "cuda":
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2**30
            multi_bald_batch_size = (torch_utils.get_cuda_available_memory() -
                                     (sample_MK_memory + copy_buffer_memory +
                                      slack_memory)) // MC_memory
            # With little free memory the estimate can be <= 0, which is not a usable
            # chunk size; always process at least one point per chunk.
            multi_bald_batch_size = max(int(multi_bald_batch_size), 1)

            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(
                    f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}"
                )
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []
        global_acquisition_bag = []
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None
        for i in range(b):
            # Once every candidate is in the bag, argmax over an all -inf score vector
            # would re-select an already acquired point.
            if len(subset_acquisition_bag) >= len(probs_B_K_C):
                break

            torch_utils.gc_cuda()

            if i > 0:
                # Compute the joint entropy
                joint_entropies_B = torch.empty((len(probs_B_K_C), ),
                                                dtype=torch.float64)

                exact_samples = num_classes**i
                if exact_samples <= num_samples:
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(
                            device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    batch_exact_joint_entropy(probs_B_K_C,
                                              prev_joint_probs_M_K,
                                              multi_bald_batch_size, device,
                                              joint_entropies_B)
                else:
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device),
                        S=num_samples_per_ws)

                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                            torch_utils.split_tensors(joint_entropies_B,
                                                      probs_B_K_C,
                                                      multi_bald_batch_size),
                            unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(joint_entropy_sampling.batch(
                            probs_b_K_C.to(device), prev_samples_M_K),
                                                non_blocking=True)

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD is:
            actual_multi_bald_B = partial_multi_bald_B[
                winner_index] - torch.sum(
                    conditional_entropies_B[subset_acquisition_bag])
            actual_multi_bald_B = actual_multi_bald_B.item()

            print(f"Actual MultiBALD: {actual_multi_bald_B}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out once at least two points have been acquired (i > 1).
            if early_out and i > 1:
                # Actual and partial scores differ by the same constant for every
                # candidate, so the spread is measured on the partial scores (the actual
                # score above is a plain float). Already-acquired points (-inf) are
                # excluded from the median.
                remaining_B = partial_multi_bald_B[partial_multi_bald_B != -math.inf]
                current_spread = partial_multi_bald_B[winner_index] - remaining_B.median()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald_B)

            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(
                subset_split.get_dataset_indices([winner_index]).item())

            print(
                f"Acquisition bag: {sorted(global_acquisition_bag)}, num_ack: {i}"
            )

    end_time = time.process_time()
    time_taken = end_time - start_time
    print('ack time taken', time_taken)

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
Example #6
0
def compute_fass_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    max_entropy_bag_size,
    fass_compute_batch_size,
    device=None,
) -> AcquisitionBatch:
    """Acquire ``b`` pool points with FASS (filtered active submodular selection).

    1. The pool is culled and BALD-scored by ``reduced_eval_consistent_bayesian_model``.
    2. Filter: the ``max_entropy_bag_size`` points with the highest entropy of the mean
       predictive distribution become candidates.
    3. Select: ``b`` candidates are picked greedily for a per-class facility-location
       objective on the pairwise distances ``hsic.sqdist(...).mean(-1)`` of the flattened
       inputs. Each candidate's class is the model's predicted class (hypothesised label).

    Args:
        max_entropy_bag_size: number of highest-entropy points kept as candidates.
        fass_compute_batch_size: if positive and smaller than the number of candidates,
            the distance matrix is computed in row chunks of this size.
        device: torch device for the distance and selection computations.

    Returns:
        ``(AcquisitionBatch, time_taken)``. The batch scores are the facility-location gain
        of each pick, and ``time_taken`` is the process time spent after model evaluation.
    """
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )

    start_time = time.process_time()

    B, _, C = list(result.logits_B_K_C.shape)
    probs_B_C = result.logits_B_K_C.exp_().mean(dim=1)
    preds_B = probs_B_C.max(dim=-1)[1]
    # p * log(p) is NaN at p == 0 (0 * -inf); its limit, and the entropy convention, is 0.
    plogp_B_C = (probs_B_C * probs_B_C.log()).masked_fill(probs_B_C == 0, 0.0)
    entropy = -plogp_B_C.sum(dim=-1)

    # Filter step: keep the highest-entropy points (indices into the scored subset).
    score_sort_idx = torch.sort(entropy, descending=True)[1]
    cand_pts_idx = set(
        score_sort_idx[:max_entropy_bag_size].cpu().numpy().tolist())

    # The candidate indices refer to the subset scored above, so the features must come
    # from the subset loader (same order as the logits rows), not the full pool loader.
    subset_loader = result.subset_dataloader
    cand_X = []
    cand_X_preds = []
    cand_X_idx = []
    for i, (batch, _labels) in enumerate(
            with_progress_bar(subset_loader,
                              unit_scale=subset_loader.batch_size)):
        lower = i * subset_loader.batch_size
        upper = min(lower + subset_loader.batch_size, B)
        idx_b = torch.tensor(
            sorted(cand_pts_idx.intersection(range(lower, upper))),
            dtype=torch.long)
        cand_X_preds.append(preds_B[idx_b])
        cand_X_idx.append(idx_b)

        batch = batch.view(batch.shape[0], -1)  # batch_size x num_features
        cand_X.append(batch[idx_b - lower])

    cand_X = torch.cat(cand_X, dim=0).unsqueeze(1).to(device)
    cand_X_preds = torch.cat(cand_X_preds, dim=0).to(device)
    cand_X_idx = torch.cat(cand_X_idx, dim=0).to(device)

    num_cands = cand_X.shape[0]
    if num_cands > fass_compute_batch_size > 0:
        # Compute the distance matrix in row chunks, gathered on the CPU.
        sqdist_rows = [
            hsic.sqdist(cand_X[bs:bs + fass_compute_batch_size], cand_X).mean(-1).cpu()
            for bs in range(0, num_cands, fass_compute_batch_size)
        ]
        sqdist = torch.cat(sqdist_rows, dim=0).to(device)
    else:
        sqdist = hsic.sqdist(cand_X, cand_X).mean(-1)  # cand_X size x cand_X size
    max_dist = sqdist.max()

    # cand_min_dist[j]: distance from candidate j to the closest acquired candidate;
    # max_dist (an upper bound on every entry) while nothing has been acquired.
    cand_min_dist = torch.ones((num_cands, ), device=sqdist.device,
                               dtype=sqdist.dtype) * max_dist
    class_masks = [cand_X_preds == c for c in range(C)]

    ack_bag = []
    global_acquisition_bag = []
    acquisition_bag_scores = []
    for _ in range(b):
        gains = torch.full_like(cand_min_dist, float("-inf"))
        for mask_c in class_masks:
            if not mask_c.any():
                continue
            # Facility-location gain of adding candidate i of class c: how much it lowers
            # the summed nearest-acquired distance of all class-c candidates j.
            cur_c = cand_min_dist[mask_c]
            new_c = torch.min(sqdist[mask_c][:, mask_c], cur_c.unsqueeze(0)).sum(dim=1)
            gains[mask_c] = cur_c.sum() - new_c
        if ack_bag:
            gains[ack_bag] = float("-inf")
        winner = gains.argmax().item()
        ack_bag.append(winner)
        acquisition_bag_scores.append(gains[winner].item())
        # Every candidate may now be covered by the new pick.
        cand_min_dist = torch.min(cand_min_dist, sqdist[winner])

        winner_index = cand_X_idx[winner].item()
        global_acquisition_bag.append(
            result.subset_split.get_dataset_indices([winner_index]).item())

    assert len(ack_bag) == b

    end_time = time.process_time()
    time_taken = end_time - start_time
    print('ack time taken', time_taken)
    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
def reduced_eval_consistent_bayesian_model(
    bayesian_model: mc_dropout.BayesianModule,
    acquisition_function: AcquisitionFunction,
    num_classes: int,
    k: int,
    initial_percentage: int,
    reduce_percentage: int,
    target_size: int,
    available_loader,
    device=None,
) -> SubsetEvalResults:
    """Score a randomly culled subset of the pool with ``k`` inference samples, shrinking it as it goes.

    Before anything is computed, ``initial_percentage`` percent of the available dataset is
    kept at random. The rest is marked acquired in a private ``ActiveLearningData`` split.

    Inference samples are then added ``chunk_size`` at a time. After a chunk, once at least
    50 samples have been drawn, the scores of ``acquisition_function`` are recomputed and
    the lowest-scoring ``reduce_percentage`` percent of the remaining points are dropped,
    never going below ``target_size``. After the final chunk the subset is cut to at most
    ``target_size`` points.

    If a chunk raises a RuntimeError for which ``torch_utils.should_reduce_batch_size`` is
    true, the chunk size is halved, stored in the module-level
    ``reduced_eval_consistent_bayesian_model_cuda_chunk_size``, and the chunk is retried.

    Returns:
        ``SubsetEvalResults`` with the split, a loader over the kept points, their scores
        and their (B, k, num_classes) float64 logits (rows aligned with the loader order).
    """
    global reduced_eval_consistent_bayesian_model_cuda_chunk_size

    # TODO: ActiveLearningData should be renamed to be a more modular SplitDataset.
    # Here, we need to use available_dataset because it allows us to easily recover the original indices.

    # We start with all data in the acquired data.
    subset_split = active_learning_data.ActiveLearningData(
        available_loader.dataset)
    initial_length = len(available_loader.dataset)

    initial_split_length = initial_length * initial_percentage // 100

    # By acquiring [initial_split_length:] of a random permutation, we make a random tail unavailable.
    subset_split.acquire(torch.randperm(initial_length)[initial_split_length:])

    subset_dataloader = data.DataLoader(subset_split.available_dataset,
                                        shuffle=False,
                                        batch_size=available_loader.batch_size)

    print(f"Scoring subset of {len(subset_dataloader.dataset)} items:")

    # We're done with available_loader in this function.
    available_loader = None

    on_cuda = device is not None and device.type == "cuda"

    with torch.no_grad():
        B = len(subset_split.available_dataset)
        C = num_classes

        # We stay on the CPU.
        logits_B_K_C = None
        # Stays None only if k == 0 (no chunk is ever scored).
        scores_B = None

        k_lower = 0
        torch_utils.gc_cuda()
        chunk_size = reduced_eval_consistent_bayesian_model_cuda_chunk_size if on_cuda else 32
        while k_lower < k:
            try:
                k_upper = min(k_lower + chunk_size, k)

                old_logits_B_K_C = logits_B_K_C
                # This also stays on the CPU.
                logits_B_K_C = torch.empty((B, k_upper, C),
                                           dtype=torch.float64)

                # Copy over the samples that are complete. After a failed chunk the old
                # buffer was sized for the failed (larger) k_upper, so slice it to k_lower.
                if k_lower > 0:
                    logits_B_K_C[:, :k_lower, :].copy_(old_logits_B_K_C[:, :k_lower, :])
                old_logits_B_K_C = None

                # This resets the dropout masks.
                bayesian_model.eval()

                for i, (batch, _) in enumerate(
                        with_progress_bar(
                            subset_dataloader,
                            unit_scale=subset_dataloader.batch_size)):
                    lower = i * subset_dataloader.batch_size
                    upper = min(lower + subset_dataloader.batch_size, B)

                    if device is not None:
                        batch = batch.to(device)
                    # batch_size x ws x classes
                    mc_output_B_K_C = bayesian_model(batch, k_upper - k_lower)
                    logits_B_K_C[lower:upper, k_lower:k_upper].copy_(
                        mc_output_B_K_C.double(), non_blocking=True)

            except RuntimeError as exception:
                if torch_utils.should_reduce_batch_size(exception):
                    if chunk_size <= 1:
                        raise
                    chunk_size = chunk_size // 2
                    print(
                        f"New reduced_eval_consistent_bayesian_model_cuda_chunk_size={chunk_size} ({exception})"
                    )
                    reduced_eval_consistent_bayesian_model_cuda_chunk_size = chunk_size

                    torch_utils.gc_cuda()
                else:
                    raise
            else:
                if k_upper == k:
                    next_size = target_size
                elif k_upper < 50:
                    next_size = B
                else:
                    next_size = max(target_size,
                                    B * (100 - reduce_percentage) // 100)

                # Compute the score if it's needed: we are going to reduce the dataset or we're in the last iteration.
                if next_size < B or k_upper == k:
                    scores_B = acquisition_function.compute_scores(
                        logits_B_K_C,
                        available_loader=subset_dataloader,
                        device=device)
                else:
                    scores_B = None

                if next_size < B:
                    print("Reducing size", next_size)
                    sorted_indices = torch.argsort(scores_B, descending=True)

                    # Keep the top scorers in their original (ascending) order so that rows
                    # stay aligned with the subset loader after the low scorers are acquired.
                    new_indices = torch.sort(sorted_indices[:next_size],
                                             descending=False)[0]

                    B = next_size
                    logits_B_K_C = logits_B_K_C[new_indices]
                    if k_upper == k:
                        logits_B_K_C = logits_B_K_C.clone().detach()
                    scores_B = scores_B[new_indices].clone().detach()

                    # Acquire all the low scorers
                    subset_split.acquire(sorted_indices[next_size:])

                # The chunk [k_lower, k_upper) is complete.
                k_lower = k_upper

    return SubsetEvalResults(subset_split=subset_split,
                             subset_dataloader=subset_dataloader,
                             scores_B=scores_B,
                             logits_B_K_C=logits_B_K_C)
Example #8
0
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,  # Number of samples to use for monte carlo sampling
    b,  # Acquisition batch size (How many samples do we want to label next)
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:
    """Greedily acquire a batch with MultiBALD (BatchBALD, Algorithm 1).

    The pool is first culled and BALD-scored by ``reduced_eval_consistent_bayesian_model``.
    Points are then added one at a time. Each step picks the candidate maximising
    ``joint entropy(bag + candidate) - conditional entropy(candidate)``.

    Args:
        b: number of points to acquire. ``0`` means up to 100 points, with an early out
            once the winner beats the median remaining candidate score by less than
            ``MIN_SPREAD``.
        device: device for the joint-entropy computations. ``device.type`` is read, so it
            must not be None.

    Returns:
        ``AcquisitionBatch`` with the dataset indices of the acquired points and their
        actual MultiBALD scores.
    """
    global compute_multi_bald_bag_multi_bald_batch_size

    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.
        bald,  # This is mutual information
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )
    # result holds the points with the highest mutual information (BALD scores);
    # the low scorers were dropped during the reduction.
    subset_split = result.subset_split

    partial_multi_bald_B = result.scores_B
    # partial_multi_bald_B contains H(y_1, ..., y_n, y_m) - E_p(w)[H(y_m|w)], n being the
    # samples already in the acquisition bag and m being all available samples that are
    # candidates to be selected into the acquisition bag. For the first sample to be
    # selected, this is equivalent to H(y_m) - E_p(w)[H(y_m|w)], i.e. the mutual
    # information of y_m and the model parameters w. Since E_p(w)[H(y_1, ..., y_n|w)],
    # which has to be subtracted to get the true a_BatchBALD, is the same for all
    # candidates, we can ignore it to find the best candidate.

    # Now we can compute the conditional entropy
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(
        result.logits_B_K_C)
    # conditional_entropies_B = E_p(w)[H(y_i|w)]. After summing together we get
    # E_p(w)[H(y_1, ..., y_n|w)], which is the right-hand side of Equation 8 to
    # calculate BatchBALD.

    # We turn the logits into probabilities (in place).
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None

    torch_utils.gc_cuda()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k  # Number of samples used to calculate joint entropy for each sample of the model
        num_samples = num_samples_per_ws * k

        # Decide how many samples should be calculated at once when determining the joint entropy
        if device.type == "cuda":
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2**30
            multi_bald_batch_size = (torch_utils.get_cuda_available_memory() -
                                     (sample_MK_memory + copy_buffer_memory +
                                      slack_memory)) // MC_memory
            # With little free memory the estimate can be <= 0, which is not a usable
            # chunk size; always process at least one point per chunk.
            multi_bald_batch_size = max(int(multi_bald_batch_size), 1)

            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(
                    f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}"
                )
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []  # Indices of currently selected samples for next labeling (local indices)
        global_acquisition_bag = []  # Indices of currently selected samples for next labeling (global indices)
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None
        # Iteratively select b samples for labeling and put them in the acquisition bag
        for i in range(b):  # Algorithm 1 : Line number 2
            # Once every candidate is in the bag, argmax over an all -inf score vector
            # would re-select an already acquired point.
            if len(subset_acquisition_bag) >= len(probs_B_K_C):
                break

            torch_utils.gc_cuda()

            if i > 0:  # Only run this starting from the second sample
                # Compute the joint entropies. Depending on the size of n (y_1, ..., y_n)
                # they are computed exactly via joint_entropy_exact or estimated via
                # joint_entropy_sampling, reusing state from the previous iteration.
                joint_entropies_B = torch.empty((len(probs_B_K_C), ),
                                                dtype=torch.float64)

                # If we can, calculate joint entropy analytically, otherwise use sampling
                exact_samples = num_classes**i
                if exact_samples <= num_samples:  # Use exact joint entropy (no sampling)
                    # Extend the previous joint probabilities with the most recently
                    # acquired point (presumably P_{1:n} over label configurations x model
                    # samples — TODO confirm against joint_entropy_exact).
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(
                            device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    batch_exact_joint_entropy(
                        probs_B_K_C,  # Class probabilities from logits_B_K_C
                        prev_joint_probs_M_K,  # Joint probabilities of the current bag
                        multi_bald_batch_size,  # Number of samples to compute at once
                        device,  # Calculate on GPU or CPU?
                        joint_entropies_B  # Filled with the resulting joint entropies
                    )
                else:  # use sampling to get joint entropy
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device),
                        S=num_samples_per_ws)
                    # prev_samples_M_K is the probability of a certain label assignment
                    # configuration for all samples in the current acquisition bag,
                    # i.e. p(y^_1:n-1|w^_j) and therefore P^_{1:n-1}

                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                            torch_utils.split_tensors(joint_entropies_B,
                                                      probs_B_K_C,
                                                      multi_bald_batch_size),
                            unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(joint_entropy_sampling.batch(
                            probs_b_K_C.to(device), prev_samples_M_K),
                                                non_blocking=True)

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            # Algorithm 1 : Line 4
            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD is:
            actual_multi_bald_B = partial_multi_bald_B[
                winner_index] - torch.sum(
                    conditional_entropies_B[subset_acquisition_bag])
            actual_multi_bald_B = actual_multi_bald_B.item()

            print(f"Actual MultiBALD: {actual_multi_bald_B}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out once at least two points have been acquired (i > 1).
            if early_out and i > 1:
                # Actual and partial scores differ by the same constant for every
                # candidate, so the spread is measured on the partial scores (the actual
                # score above is a plain float). Already-acquired points (-inf) are
                # excluded from the median.
                remaining_B = partial_multi_bald_B[partial_multi_bald_B != -math.inf]
                current_spread = partial_multi_bald_B[winner_index] - remaining_B.median()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald_B)
            # Algorithm 1 : Line 5
            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(
                subset_split.get_dataset_indices([winner_index]).item())

            print(f"Acquisition bag: {sorted(global_acquisition_bag)}")

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None)