Example #1
def batch(probs_B_K_C, samples_M_K):
    probs_B_K_C = probs_B_K_C.double()
    samples_M_K = samples_M_K.double()

    device = probs_B_K_C.device
    M, K = samples_M_K.shape
    B, K_, C = probs_B_K_C.shape
    assert K == K_

    p_B_M_C = torch.empty((B, M, C), dtype=torch.float64, device=device)

    for i in range(B):
        torch.matmul(samples_M_K, probs_B_K_C[i], out=p_B_M_C[i])

    p_B_M_C /= K

    q_1_M_1 = samples_M_K.mean(dim=1, keepdim=True)[None]

    # Now we can compute the entropy.
    # We store it directly on the CPU to save GPU memory.
    entropy_B = torch.zeros((B,), dtype=torch.float64)

    chunk_size = 256
    for entropy_b, p_b_M_C in split_tensors(entropy_B, p_B_M_C, chunk_size):
        entropy_b.copy_(importance_weighted_entropy_p_b_M_C(p_b_M_C, q_1_M_1, M), non_blocking=True)

    return entropy_B
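
The per-point loop above performs an (M, K) x (K, C) contraction for every pool point. A minimal standalone shape sketch of that contraction (toy sizes and illustrative names, not taken from the project):

import torch

B, K, C, M = 4, 10, 3, 6  # pool points, MC weight samples, classes, sampled configurations
probs_B_K_C = torch.softmax(torch.randn(B, K, C, dtype=torch.float64), dim=-1)
samples_M_K = torch.rand(M, K, dtype=torch.float64)  # stand-in for the sampled configuration traces

# Same contraction as the per-point matmul loop above: (M, K) @ (K, C) -> (M, C) for each pool point.
p_B_M_C = torch.einsum("mk,bkc->bmc", samples_M_K, probs_B_K_C) / K
print(p_B_M_C.shape)  # torch.Size([4, 6, 3])
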
Example #2
    def compute_scores(self, logits_B_K_C, available_loader, device):
        """Compute an acquisition score for every pool point by applying the scorer to the logits in batches."""
        scorer = self.scorer

        if self == AcquisitionFunction.random:
            return scorer(logits_B_K_C, None).double()

        B, K, C = logits_B_K_C.shape

        # We need to sample the predictions from the bayesian_model n times and store them.
        with torch.no_grad():
            scores_B = torch.empty((B,), dtype=torch.float64)

            if device.type == "cuda":
                torch_utils.gc_cuda()
                KC_memory = K * C * 8
                batch_size = min(torch_utils.get_cuda_available_memory() // KC_memory, 8192)
            else:
                batch_size = 4096

            for scores_b, logits_b_K_C in with_progress_bar(
                torch_utils.split_tensors(scores_B, logits_B_K_C, batch_size), unit_scale=batch_size
            ):
                scores_b.copy_(scorer(logits_b_K_C.to(device)), non_blocking=True)

        return scores_B
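
A rough walk-through of the batch_size heuristic above, with made-up numbers: one pool point of float64 logits occupies K * C * 8 bytes, so the available CUDA memory bounds how many points are moved to the GPU at once, capped at 8192.

K, C = 100, 10                       # illustrative values
KC_memory = K * C * 8                # 8000 bytes per pool point in float64
available = 2 * 2**30                # pretend 2 GiB of CUDA memory are free
batch_size = min(available // KC_memory, 8192)
print(batch_size)                    # 8192: the cap applies long before memory runs out here
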
Example #3
File: exact.py Project: tmfs10/ical
def batch(probs_B_K_C, prev_joint_probs_M_K=None):
    if prev_joint_probs_M_K is not None:
        assert prev_joint_probs_M_K.shape[1] == probs_B_K_C.shape[1]

    device = probs_B_K_C.device
    B, K, C = probs_B_K_C.shape
    probs_B_K_C = probs_B_K_C.double()

    if prev_joint_probs_M_K is None:
        prev_joint_probs_M_K = torch.ones((1, K),
                                          dtype=torch.float64,
                                          device=device)

    joint_probs_B_M_C = entropy_joint_probs_B_M_C(probs_B_K_C,
                                                  prev_joint_probs_M_K)

    # Now we can compute the entropy.
    entropy_B = torch.zeros((B, ), dtype=torch.float64, device=device)

    chunk_size = 256
    for entropy_b, joint_probs_b_M_C in split_tensors(entropy_B,
                                                      joint_probs_B_M_C,
                                                      chunk_size):
        entropy_b.copy_(entropy_from_probs_b_M_C(joint_probs_b_M_C),
                        non_blocking=True)

    return entropy_B
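
When prev_joint_probs_M_K defaults to ones((1, K)), the "joint" collapses to the Monte-Carlo-averaged predictive distribution, so the returned value is simply the predictive entropy H(y) of each pool point. A standalone sketch of that base case (toy shapes, following the expanded form shown in Example #4 below; not the project's code):

import torch

B, K, C = 3, 5, 4
probs_B_K_C = torch.softmax(torch.randn(B, K, C, dtype=torch.float64), dim=-1)
ones_1_K = torch.ones((1, K), dtype=torch.float64)

# joint = prev_joint_probs @ probs / K, as in Example #4.
joint_B_1_C = torch.einsum("mk,bkc->bmc", ones_1_K, probs_B_K_C) / K
assert torch.allclose(joint_B_1_C[:, 0, :], probs_B_K_C.mean(dim=1))  # the averaged predictive distribution
entropy_B = torch.sum(-joint_B_1_C * joint_B_1_C.log(), dim=(1, 2))
print(entropy_B)  # predictive entropy H(y) per pool point
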
Example #4
def batch(probs_B_K_C, prev_joint_probs_M_K=None):
    if prev_joint_probs_M_K is not None:
        assert prev_joint_probs_M_K.shape[1] == probs_B_K_C.shape[1]

    device = probs_B_K_C.device
    B, K, C = probs_B_K_C.shape
    probs_B_K_C = probs_B_K_C.double()

    if prev_joint_probs_M_K is None:
        prev_joint_probs_M_K = torch.ones((1, K), dtype=torch.float64, device=device)

    M = prev_joint_probs_M_K.shape[0]
    joint_probs_B_M_C = torch.empty((B, M, C), dtype=torch.float64, device=device)

    for i in range(B):
        torch.matmul(prev_joint_probs_M_K, probs_B_K_C[i], out=joint_probs_B_M_C[i])

    joint_probs_B_M_C /= K

    # Now we can compute the entropy.
    entropy_B = torch.zeros((B,), dtype=torch.float64, device=device)

    chunk_size = 256
    for entropy_b, joint_probs_b_M_C in split_tensors(entropy_B, joint_probs_B_M_C, chunk_size):
        entropy_b.copy_(torch.sum(-joint_probs_b_M_C * torch.log(joint_probs_b_M_C), dim=(1, 2)), non_blocking=True)

    return entropy_B
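
A standalone check of the chunked entropy reduction in the loop above: the joint over the (M, C) cells carries total mass 1, so a uniform joint over 3 * 4 = 12 cells must give entropy log(12). Toy values only:

import torch

joint_b_M_C = torch.full((2, 3, 4), 1.0 / 12, dtype=torch.float64)  # two pool points, uniform joints
entropy_b = torch.sum(-joint_b_M_C * torch.log(joint_b_M_C), dim=(1, 2))
print(entropy_b)  # tensor([2.4849, 2.4849], dtype=torch.float64) == log(12)
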
Example #5
def batch_exact_joint_entropy(probs_B_K_C, prev_joint_probs_M_K, chunk_size,
                              device, out_joint_entropies_B):
    """This one switches between devices, too."""
    for joint_entropies_b, probs_b_K_C in tqdm(torch_utils.split_tensors(
            out_joint_entropies_B, probs_B_K_C, chunk_size),
                                               unit_scale=chunk_size):
        joint_entropies_b.copy_(joint_entropy_exact.batch(
            probs_b_K_C.to(device), prev_joint_probs_M_K),
                                non_blocking=True)

    return out_joint_entropies_B
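
All of these examples lean on torch_utils.split_tensors to pair chunks of a preallocated output tensor with the corresponding chunks of an input tensor. A hypothetical stand-in that is consistent with how the helper is used here (not the project's actual implementation):

import torch

def split_tensors(output, input, chunk_size):
    # Hypothetical helper: yield matching chunks of an output and an input tensor.
    assert output.shape[0] == input.shape[0]
    return zip(output.split(chunk_size), input.split(chunk_size))

out = torch.zeros(5)
inp = torch.arange(5.0)
for out_chunk, in_chunk in split_tensors(out, inp, 2):
    out_chunk.copy_(in_chunk * 10)  # writes through the chunk views into `out`
print(out)  # tensor([ 0., 10., 20., 30., 40.])
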
Example #6
def batch_conditional_entropy_B(logits_B_K_C, out_conditional_entropy_B=None):
    B, K, C = logits_B_K_C.shape

    if out_conditional_entropy_B is None:
        out_conditional_entropy_B = torch.empty((B,), dtype=torch.float64)
    else:
        assert out_conditional_entropy_B.shape == (B,)

    for conditional_entropy_b, logits_b_K_C in split_tensors(out_conditional_entropy_B, logits_B_K_C, 8192):
        logits_b_K_C = logits_b_K_C.double()
        conditional_entropy_b.copy_(
            torch.sum(-logits_b_K_C * torch.exp(logits_b_K_C), dim=(1, 2)) / K, non_blocking=True
        )

    return out_conditional_entropy_B
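
A toy check of the reduction above, under the assumption the code makes that logits_B_K_C holds log-probabilities: -sum(exp(l) * l) over C is the entropy of one sampled predictive distribution, and dividing by K averages it over the K Monte Carlo samples.

import torch

probs_K_C = torch.tensor([[0.5, 0.5], [0.9, 0.1]], dtype=torch.float64)  # K=2, C=2
log_probs_1_K_C = probs_K_C.log()[None]                                  # add the B dimension
K = log_probs_1_K_C.shape[1]
cond_entropy_1 = torch.sum(-log_probs_1_K_C.exp() * log_probs_1_K_C, dim=(1, 2)) / K
print(cond_entropy_1)  # (H([0.5, 0.5]) + H([0.9, 0.1])) / 2 ≈ (0.6931 + 0.3251) / 2 ≈ 0.5091
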
Example #7
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )
    start_time = time.process_time()

    subset_split = result.subset_split

    partial_multi_bald_B = result.scores_B
    # Now we can compute the conditional entropy
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(
        result.logits_B_K_C)

    # We turn the logits into probabilities.
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None

    torch_utils.gc_cuda()
    # torch_utils.cuda_meminfo()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k
        num_samples = num_samples_per_ws * k

        if device.type == "cuda":
            # KC_memory = k*num_classes*8
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2**30
            multi_bald_batch_size = (torch_utils.get_cuda_available_memory() -
                                     (sample_MK_memory + copy_buffer_memory +
                                      slack_memory)) // MC_memory

            global compute_multi_bald_bag_multi_bald_batch_size
            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(
                    f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}"
                )
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []
        global_acquisition_bag = []
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None
        for i in range(b):
            torch_utils.gc_cuda()

            if i > 0:
                # Compute the joint entropy
                joint_entropies_B = torch.empty((len(probs_B_K_C), ),
                                                dtype=torch.float64)

                exact_samples = num_classes**i
                if exact_samples <= num_samples:
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(
                            device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    # torch_utils.cuda_meminfo()
                    batch_exact_joint_entropy(probs_B_K_C,
                                              prev_joint_probs_M_K,
                                              multi_bald_batch_size, device,
                                              joint_entropies_B)
                else:
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device),
                        S=num_samples_per_ws)

                    # torch_utils.cuda_meminfo()
                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                            torch_utils.split_tensors(joint_entropies_B,
                                                      probs_B_K_C,
                                                      multi_bald_batch_size),
                            unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(joint_entropy_sampling.batch(
                            probs_b_K_C.to(device), prev_samples_M_K),
                                                non_blocking=True)

                        # torch_utils.cuda_meminfo()

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD is the partial score minus the summed conditional
            # entropies of the points already in the acquisition bag.
            actual_multi_bald_B = partial_multi_bald_B - torch.sum(
                conditional_entropies_B[subset_acquisition_bag])

            print(f"Actual MultiBALD: {actual_multi_bald_B[winner_index].item()}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out after acquiring at least 1 sample.
            if early_out and i > 1:
                current_spread = actual_multi_bald_B[
                    winner_index] - actual_multi_bald_B.median()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald_B[winner_index].item())

            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(
                subset_split.get_dataset_indices([winner_index]).item())

            print(
                f"Acquisition bag: {sorted(global_acquisition_bag)}, num_ack: {i}"
            )

    end_time = time.process_time()
    time_taken = end_time - start_time
    print('ack time taken', time_taken)

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None), time_taken
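
The exact-vs-sampling switch above hinges on num_classes**i, the number of label configurations of the points already acquired, measured against the roughly 40000-sample budget. A small illustration with made-up numbers:

num_classes, k = 10, 100
num_samples = (40000 // k) * k                    # 40000
for i in range(1, 7):
    exact_samples = num_classes ** i
    branch = "exact" if exact_samples <= num_samples else "sampling"
    print(i, exact_samples, branch)
# With these numbers, iterations 1-4 use the exact joint (10 ... 10000 configurations);
# from the 5th acquisition (100000 configurations) onwards the sampled estimator takes over.
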
Example #8
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,  # Number of samples to use for monte carlo sampling
    b,  # Acquisition batch size (How many samples do we want to label next)
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:

    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,  # This is mutual information
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )
    # Result contains a reduced pool of candidate samples together with their mutual-information scores
    subset_split = result.subset_split

    partial_multi_bald_B = result.scores_B
    # partial_multi_bald_B contains H(y_1, ..., y_n, y_m) -
    # E_p(w)[H(y_m|w)], where 1..n are the samples already in the
    # acquisition bag and m ranges over the available samples that are
    # candidates for the acquisition bag. For the first sample to be
    # selected this reduces to H(y_m) - E_p(w)[H(y_m|w)], i.e. the
    # mutual information between y_m and the model parameters w. The
    # term E_p(w)[H(y_1, ..., y_n|w)] that still has to be subtracted
    # to obtain the true a_BatchBALD score is the same for every
    # candidate, so it can be ignored when picking the best one.

    # Now we can compute the conditional entropy
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(
        result.logits_B_K_C)
    # conditional_entropies_B = E_p(w)[H(y_i|w)]. Summing these
    # together gives E_p(w)[H(y_1, ..., y_n|w)], the right-hand side
    # of Equation 8 used to calculate BatchBALD.

    # We turn the logits into probabilities.
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None

    torch_utils.gc_cuda()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k  # Number of samples used to calculate joint entropy for each sample of the model
        num_samples = num_samples_per_ws * k

        # Decide how many samples should be calculated at once when determining the joint entropy
        if device.type == "cuda":
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2**30
            multi_bald_batch_size = (torch_utils.get_cuda_available_memory() -
                                     (sample_MK_memory + copy_buffer_memory +
                                      slack_memory)) // MC_memory

            global compute_multi_bald_bag_multi_bald_batch_size
            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(
                    f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}"
                )
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []  # Indices of currently selected samples for next labeling (local indices)
        global_acquisition_bag = []  # Indices of currently selected samples for next labeling (global indices)
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None
        # Iteratively select b samples for labeling and put them in
        # the acquisition_bag
        for i in range(b):  # Algorithm 1 : Line number 2
            torch_utils.gc_cuda()

            if i > 0:  # Only run this starting from the second sample

                # Compute the joint entropies. Depending on the size
                # of n (y_1, ..., y_n) we can either solve this
                # analytically using joint_entropy_exact or via
                # sampling using joint_entropy_sampling. The entropies
                # can be computed iteratively, reusing information
                # obtained when the last point was added to the bag.

                joint_entropies_B = torch.empty((len(probs_B_K_C), ),
                                                dtype=torch.float64)

                # If we can, calculate joint entropy analytically, otherwise use sampling
                exact_samples = num_classes**i
                if exact_samples <= num_samples:  # Use exact joint entropy (no sampling)
                    # Update P_{1:n-1}: joint probabilities over the configurations of the points already in the bag
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(
                            device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    batch_exact_joint_entropy(
                        probs_B_K_C,  # Class probabilities from logits_B_K_C
                        prev_joint_probs_M_K,  # Joint probabilities of the current acquisition bag (P_{1:n-1})
                        multi_bald_batch_size,  # Number of samples to compute at once
                        device,  # Calculate on GPU or CPU?
                        joint_entropies_B  # Filled with the resulting joint entropies
                    )
                else:  # use sampling to get joint entropy
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device),
                        S=num_samples_per_ws)
                    # prev_samples_M_K holds the probability of each sampled label
                    # configuration for the samples currently in the acquisition bag,
                    # i.e. p(y^_1:n-1|w^_j), and therefore P^_{1:n-1}

                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                            torch_utils.split_tensors(joint_entropies_B,
                                                      probs_B_K_C,
                                                      multi_bald_batch_size),
                            unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(joint_entropy_sampling.batch(
                            probs_b_K_C.to(device), prev_samples_M_K),
                                                non_blocking=True)

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            # Algorithm 1 : Line 4

            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD is the partial score minus the summed conditional
            # entropies of the points already in the acquisition bag.
            actual_multi_bald_B = partial_multi_bald_B - torch.sum(
                conditional_entropies_B[subset_acquisition_bag])

            print(f"Actual MultiBALD: {actual_multi_bald_B[winner_index].item()}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out after acquiring at least 1 sample.
            if early_out and i > 1:
                current_spread = actual_multi_bald_B[
                    winner_index] - actual_multi_bald_B.median()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald_B[winner_index].item())
            # Algorithm 1 : Line 5
            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(
                subset_split.get_dataset_indices([winner_index]).item())

            print(f"Acquisition bag: {sorted(global_acquisition_bag)}")

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores,
                            None)
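
A small numeric sketch of the bookkeeping at the end of the loop (values are made up): the partial score is H(y_1, ..., y_{n-1}, y_m) - E_w[H(y_m|w)], and subtracting the conditional entropies of the points already in the bag gives the full BatchBALD value that the code reports.

import torch

joint_entropy_winner = torch.tensor(2.10)        # H(y_1, ..., y_{n-1}, y_m*) for the argmax point
cond_entropy_winner = torch.tensor(0.40)         # E_w[H(y_m* | w)]
cond_entropies_bag = torch.tensor([0.35, 0.50])  # E_w[H(y_i | w)] for points already acquired

partial_score = joint_entropy_winner - cond_entropy_winner    # what partial_multi_bald_B stores
actual_multi_bald = partial_score - cond_entropies_bag.sum()  # the reported BatchBALD value
print(actual_multi_bald)  # tensor(0.8500)
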