def batch(probs_B_K_C, samples_M_K):
    probs_B_K_C = probs_B_K_C.double()
    samples_M_K = samples_M_K.double()

    device = probs_B_K_C.device

    M, K = samples_M_K.shape
    B, K_, C = probs_B_K_C.shape
    assert K == K_

    p_B_M_C = torch.empty((B, M, C), dtype=torch.float64, device=device)
    for i in range(B):
        torch.matmul(samples_M_K, probs_B_K_C[i], out=p_B_M_C[i])
    p_B_M_C /= K

    q_1_M_1 = samples_M_K.mean(dim=1, keepdim=True)[None]

    # Now we can compute the entropy.
    # We store it directly on the CPU to save GPU memory.
    entropy_B = torch.zeros((B,), dtype=torch.float64)

    chunk_size = 256
    for entropy_b, p_b_M_C in split_tensors(entropy_B, p_B_M_C, chunk_size):
        entropy_b.copy_(importance_weighted_entropy_p_b_M_C(p_b_M_C, q_1_M_1, M), non_blocking=True)

    return entropy_B
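The helper importance_weighted_entropy_p_b_M_C is not part of this listing, so the exact estimator is inferred from how p_B_M_C and q_1_M_1 are built above: with M sampled configurations ŷ^m_{1:n-1} (rows of samples_M_K, one probability per posterior sample w_k), the code appears to evaluate the self-normalised importance-sampling estimate of the joint entropy,

\[
\hat p_{m,c} = \frac{1}{K}\sum_{k=1}^{K} P(\hat y^{m}_{1:n-1}\mid w_k)\,p(y=c\mid w_k),\qquad
\hat q_m = \frac{1}{K}\sum_{k=1}^{K} P(\hat y^{m}_{1:n-1}\mid w_k),
\]
\[
\hat H \;\approx\; \frac{1}{M}\sum_{m=1}^{M}\sum_{c=1}^{C}\frac{-\hat p_{m,c}\,\log \hat p_{m,c}}{\hat q_m},
\]

computed per pool point in chunks of 256 to bound memory use.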
def compute_scores(self, logits_B_K_C, available_loader, device):
    """Compute per-point acquisition scores for all candidates in the pool."""
    scorer = self.scorer

    if self == AcquisitionFunction.random:
        return scorer(logits_B_K_C, None).double()

    B, K, C = logits_B_K_C.shape

    # We need to sample the predictions from the bayesian_model n times and store them.
    with torch.no_grad():
        scores_B = torch.empty((B,), dtype=torch.float64)

        if device.type == "cuda":
            torch_utils.gc_cuda()
            KC_memory = K * C * 8
            batch_size = min(torch_utils.get_cuda_available_memory() // KC_memory, 8192)
        else:
            batch_size = 4096

        for scores_b, logits_b_K_C in with_progress_bar(
            torch_utils.split_tensors(scores_B, logits_B_K_C, batch_size), unit_scale=batch_size
        ):
            scores_b.copy_(scorer(logits_b_K_C.to(device)), non_blocking=True)

    return scores_B
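As a quick sanity check of the memory-based chunking above (the numbers here are illustrative assumptions, not values from the experiments):

# Illustrative only: K, C and the free-memory figure are assumed values.
K, C = 100, 10                       # MC samples and number of classes
KC_memory = K * C * 8                # bytes per pool point in float64
free_bytes = 4 * 2 ** 30             # pretend 4 GiB of CUDA memory is free
batch_size = min(free_bytes // KC_memory, 8192)
print(batch_size)                    # 8192: the cap applies, since 4 GiB // 8000 B = 536870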
def batch(probs_B_K_C, prev_joint_probs_M_K=None):
    if prev_joint_probs_M_K is not None:
        assert prev_joint_probs_M_K.shape[1] == probs_B_K_C.shape[1]

    device = probs_B_K_C.device
    B, K, C = probs_B_K_C.shape
    probs_B_K_C = probs_B_K_C.double()

    if prev_joint_probs_M_K is None:
        prev_joint_probs_M_K = torch.ones((1, K), dtype=torch.float64, device=device)

    joint_probs_B_M_C = entropy_joint_probs_B_M_C(probs_B_K_C, prev_joint_probs_M_K)

    # Now we can compute the entropy.
    entropy_B = torch.zeros((B,), dtype=torch.float64, device=device)

    chunk_size = 256
    for entropy_b, joint_probs_b_M_C in split_tensors(entropy_B, joint_probs_B_M_C, chunk_size):
        entropy_b.copy_(entropy_from_probs_b_M_C(joint_probs_b_M_C), non_blocking=True)

    return entropy_B
def batch(probs_B_K_C, prev_joint_probs_M_K=None):
    if prev_joint_probs_M_K is not None:
        assert prev_joint_probs_M_K.shape[1] == probs_B_K_C.shape[1]

    device = probs_B_K_C.device
    B, K, C = probs_B_K_C.shape
    probs_B_K_C = probs_B_K_C.double()

    if prev_joint_probs_M_K is None:
        prev_joint_probs_M_K = torch.ones((1, K), dtype=torch.float64, device=device)

    M = prev_joint_probs_M_K.shape[0]
    joint_probs_B_M_C = torch.empty((B, M, C), dtype=torch.float64, device=device)
    for i in range(B):
        torch.matmul(prev_joint_probs_M_K, probs_B_K_C[i], out=joint_probs_B_M_C[i])
    joint_probs_B_M_C /= K

    # Now we can compute the entropy.
    entropy_B = torch.zeros((B,), dtype=torch.float64, device=device)

    chunk_size = 256
    for entropy_b, joint_probs_b_M_C in split_tensors(entropy_B, joint_probs_B_M_C, chunk_size):
        entropy_b.copy_(torch.sum(-joint_probs_b_M_C * torch.log(joint_probs_b_M_C), dim=(1, 2)), non_blocking=True)

    return entropy_B
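In tensor terms, the loop above forms, for every pool point b, the product of P̂_{1:n-1} (shape M×K, one row per enumerated label configuration of the points already in the bag) with that point's per-sample class probabilities (shape K×C):

\[
P_{b,m,c} \;=\; \frac{1}{K}\sum_{k=1}^{K} \hat P_{1:n-1}[m,k]\;p(y_b=c\mid w_k),\qquad
\hat H_b \;=\; -\sum_{m=1}^{M}\sum_{c=1}^{C} P_{b,m,c}\,\log P_{b,m,c},
\]

which is the exact joint entropy of the current bag extended by candidate b, with M = C^{n-1} configurations enumerated explicitly.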
def batch_exact_joint_entropy(probs_B_K_C, prev_joint_probs_M_K, chunk_size, device, out_joint_entropies_B):
    """This one switches between devices, too."""
    for joint_entropies_b, probs_b_K_C in tqdm(
        torch_utils.split_tensors(out_joint_entropies_B, probs_B_K_C, chunk_size), unit_scale=chunk_size
    ):
        joint_entropies_b.copy_(
            joint_entropy_exact.batch(probs_b_K_C.to(device), prev_joint_probs_M_K), non_blocking=True
        )

    return out_joint_entropies_B
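torch_utils.split_tensors itself is not reproduced in these listings. Based on how every call site uses it (iterating over aligned output/input chunks along the first dimension), a minimal sketch might look like the following; this is an inferred stand-in, not the repository's implementation:

def split_tensors(output, input, chunk_size):
    # Sketch only: yields aligned views (output[i:i+chunk_size], input[i:i+chunk_size])
    # so results computed on the input chunk can be copied into the matching output chunk.
    assert output.shape[0] == input.shape[0]
    for start in range(0, output.shape[0], chunk_size):
        yield output[start:start + chunk_size], input[start:start + chunk_size]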
def batch_conditional_entropy_B(logits_B_K_C, out_conditional_entropy_B=None):
    B, K, C = logits_B_K_C.shape

    if out_conditional_entropy_B is None:
        out_conditional_entropy_B = torch.empty((B,), dtype=torch.float64)
    else:
        assert out_conditional_entropy_B.shape == (B,)

    for conditional_entropy_b, logits_b_K_C in split_tensors(out_conditional_entropy_B, logits_B_K_C, 8192):
        logits_b_K_C = logits_b_K_C.double()
        conditional_entropy_b.copy_(
            torch.sum(-logits_b_K_C * torch.exp(logits_b_K_C), dim=(1, 2)) / K, non_blocking=True
        )

    return out_conditional_entropy_B
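Because logits_B_K_C holds log-probabilities (the caller later applies .exp_() to obtain probabilities), the expression inside the loop is a Monte-Carlo estimate of the conditional-entropy term of BatchBALD:

\[
\mathbb{E}_{p(w)}\!\left[H(y_b\mid w)\right]\;\approx\;-\frac{1}{K}\sum_{k=1}^{K}\sum_{c=1}^{C} p(y_b=c\mid w_k)\,\log p(y_b=c\mid w_k).
\]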
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,
    b,
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )

    start_time = time.process_time()

    subset_split = result.subset_split
    partial_multi_bald_B = result.scores_B

    # Now we can compute the conditional entropy.
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(result.logits_B_K_C)

    # We turn the logits into probabilities.
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None
    torch_utils.gc_cuda()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k
        num_samples = num_samples_per_ws * k

        if device.type == "cuda":
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2 ** 30
            multi_bald_batch_size = (
                torch_utils.get_cuda_available_memory() - (sample_MK_memory + copy_buffer_memory + slack_memory)
            ) // MC_memory

            global compute_multi_bald_bag_multi_bald_batch_size
            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}")
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []
        global_acquisition_bag = []
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None

        for i in range(b):
            torch_utils.gc_cuda()

            if i > 0:
                # Compute the joint entropy.
                joint_entropies_B = torch.empty((len(probs_B_K_C),), dtype=torch.float64)

                exact_samples = num_classes ** i
                if exact_samples <= num_samples:
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    batch_exact_joint_entropy(
                        probs_B_K_C, prev_joint_probs_M_K, multi_bald_batch_size, device, joint_entropies_B
                    )
                else:
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device), S=num_samples_per_ws
                    )

                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                        torch_utils.split_tensors(joint_entropies_B, probs_B_K_C, multi_bald_batch_size),
                        unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(
                            joint_entropy_sampling.batch(probs_b_K_C.to(device), prev_samples_M_K), non_blocking=True
                        )

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection.
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD is:
            actual_multi_bald_B = partial_multi_bald_B[winner_index] - torch.sum(
                conditional_entropies_B[subset_acquisition_bag]
            )
            actual_multi_bald_B = actual_multi_bald_B.item()
            print(f"Actual MultiBALD: {actual_multi_bald_B}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out after acquiring at least 1 sample.
            if early_out and i > 1:
                # Spread between the best candidate score and the median candidate score.
                current_spread = partial_multi_bald_B[winner_index] - partial_multi_bald_B[
                    torch.isfinite(partial_multi_bald_B)
                ].median()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald_B)
            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(subset_split.get_dataset_indices([winner_index]).item())

            print(f"Acquisition bag: {sorted(global_acquisition_bag)}, num_ack: {i}")

    end_time = time.process_time()
    time_taken = end_time - start_time
    print("ack time taken", time_taken)

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores, None), time_taken
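A hypothetical call might look like the following; model, pool_loader and all hyperparameter values are placeholders chosen for illustration, not settings taken from the experiments:

# Placeholder usage sketch.
acquisition, time_taken = compute_multi_bald_batch(
    bayesian_model=model,          # an MC-dropout model (placeholder)
    available_loader=pool_loader,  # DataLoader over the unlabeled pool (placeholder)
    num_classes=10,
    k=100,                         # posterior samples per point
    b=10,                          # acquisition batch size
    target_size=4096,
    initial_percentage=100,
    reduce_percentage=0,
    device=torch.device("cuda"),
)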
def compute_multi_bald_batch(
    bayesian_model: nn.Module,
    available_loader,
    num_classes,
    k,  # Number of Monte-Carlo samples drawn from the model
    b,  # Acquisition batch size (how many samples we want to label next)
    target_size,
    initial_percentage,
    reduce_percentage,
    device=None,
) -> AcquisitionBatch:
    result = reduced_eval_consistent_bayesian_model(
        bayesian_model=bayesian_model,
        acquisition_function=AcquisitionFunction.bald,  # This is mutual information
        num_classes=num_classes,
        k=k,
        initial_percentage=initial_percentage,
        reduce_percentage=reduce_percentage,
        target_size=target_size,
        available_loader=available_loader,
        device=device,
    )
    # result contains a reduced pool of candidate samples together with their mutual-information scores.

    subset_split = result.subset_split
    partial_multi_bald_B = result.scores_B
    # partial_multi_bald_B contains H(y_1, ..., y_n, y_m) - E_p(w)[H(y_m|w)],
    # where 1..n are the samples already in the acquisition bag and m ranges over
    # all available samples that are candidates for the acquisition bag. For the
    # first sample to be selected this equals H(y_m) - E_p(w)[H(y_m|w)], i.e. the
    # mutual information between y_m and the model parameters w. The term
    # E_p(w)[H(y_1, ..., y_n|w)] that would have to be subtracted to obtain the
    # true value of a_BatchBALD is the same for all candidates, so it can be
    # ignored when searching for the best one.

    # Now we can compute the conditional entropy.
    conditional_entropies_B = joint_entropy_exact.batch_conditional_entropy_B(result.logits_B_K_C)
    # conditional_entropies_B[i] = E_p(w)[H(y_i|w)]. Summing these up gives
    # E_p(w)[H(y_1, ..., y_n|w)], the right-hand term of Equation 8 used to
    # compute BatchBALD.

    # We turn the logits into probabilities.
    probs_B_K_C = result.logits_B_K_C.exp_()

    # Don't need the result anymore.
    result = None
    torch_utils.gc_cuda()

    with torch.no_grad():
        num_samples_per_ws = 40000 // k  # Samples used to estimate the joint entropy, per model sample
        num_samples = num_samples_per_ws * k

        # Decide how many pool points are processed at once when computing the joint entropy.
        if device.type == "cuda":
            sample_MK_memory = num_samples * k * 8
            MC_memory = num_samples * num_classes * 8
            copy_buffer_memory = 256 * num_samples * num_classes * 8
            slack_memory = 2 * 2 ** 30
            multi_bald_batch_size = (
                torch_utils.get_cuda_available_memory() - (sample_MK_memory + copy_buffer_memory + slack_memory)
            ) // MC_memory

            global compute_multi_bald_bag_multi_bald_batch_size
            if compute_multi_bald_bag_multi_bald_batch_size != multi_bald_batch_size:
                compute_multi_bald_bag_multi_bald_batch_size = multi_bald_batch_size
                print(f"New compute_multi_bald_bag_multi_bald_batch_size = {multi_bald_batch_size}")
        else:
            multi_bald_batch_size = 16

        subset_acquisition_bag = []  # Currently selected samples for the next labeling round (local indices)
        global_acquisition_bag = []  # Currently selected samples for the next labeling round (global indices)
        acquisition_bag_scores = []

        # We use this for early-out in the b==0 case.
        MIN_SPREAD = 0.1

        if b == 0:
            b = 100
            early_out = True
        else:
            early_out = False

        prev_joint_probs_M_K = None
        prev_samples_M_K = None

        # Iteratively select b samples for labeling and put them in the acquisition bag.
        for i in range(b):  # Algorithm 1: line 2
            torch_utils.gc_cuda()

            if i > 0:  # Only run this starting from the second sample.
                # Compute the joint entropies. Depending on the size of n in
                # (y_1, ..., y_n), we can either solve this analytically with
                # joint_entropy_exact or approximate it with joint_entropy_sampling.
                # The entropies are built up iteratively, reusing what was computed
                # when the previous sample was added to the bag.
                joint_entropies_B = torch.empty((len(probs_B_K_C),), dtype=torch.float64)

                # If feasible, compute the joint entropy exactly, otherwise fall back to sampling.
                exact_samples = num_classes ** i
                if exact_samples <= num_samples:
                    # Exact joint entropy (no sampling).
                    # prev_joint_probs_M_K presumably corresponds to P̂_{1:n-1}.
                    prev_joint_probs_M_K = joint_entropy_exact.joint_probs_M_K(
                        probs_B_K_C[subset_acquisition_bag[-1]][None].to(device),
                        prev_joint_probs_M_K=prev_joint_probs_M_K,
                    )

                    batch_exact_joint_entropy(
                        probs_B_K_C,  # Class probabilities from logits_B_K_C
                        prev_joint_probs_M_K,  # Joint probabilities of the configurations already in the bag
                        multi_bald_batch_size,  # Number of pool points to process at once
                        device,  # Compute on GPU or CPU
                        joint_entropies_B,  # Filled with the resulting joint entropies
                    )
                else:
                    # Use sampling to estimate the joint entropy.
                    if prev_joint_probs_M_K is not None:
                        prev_joint_probs_M_K = None
                        torch_utils.gc_cuda()

                    # Gather new traces for the new subset_acquisition_bag.
                    prev_samples_M_K = joint_entropy_sampling.sample_M_K(
                        probs_B_K_C[subset_acquisition_bag].to(device), S=num_samples_per_ws
                    )
                    # prev_samples_M_K holds the probability of a sampled label
                    # configuration for all samples in the current acquisition bag,
                    # i.e. p(ŷ_{1:n-1} | ŵ_j), and therefore P̂_{1:n-1}.

                    for joint_entropies_b, probs_b_K_C in with_progress_bar(
                        torch_utils.split_tensors(joint_entropies_B, probs_B_K_C, multi_bald_batch_size),
                        unit_scale=multi_bald_batch_size,
                    ):
                        joint_entropies_b.copy_(
                            joint_entropy_sampling.batch(probs_b_K_C.to(device), prev_samples_M_K), non_blocking=True
                        )

                    prev_samples_M_K = None
                    torch_utils.gc_cuda()

                partial_multi_bald_B = joint_entropies_B - conditional_entropies_B
                joint_entropies_B = None

            # Don't allow reselection.
            partial_multi_bald_B[subset_acquisition_bag] = -math.inf

            # Algorithm 1: line 4
            winner_index = partial_multi_bald_B.argmax().item()

            # Actual MultiBALD is:
            actual_multi_bald_B = partial_multi_bald_B[winner_index] - torch.sum(
                conditional_entropies_B[subset_acquisition_bag]
            )
            actual_multi_bald_B = actual_multi_bald_B.item()
            print(f"Actual MultiBALD: {actual_multi_bald_B}")

            # If we early out, we don't take the point that triggers the early out.
            # Only allow early-out after acquiring at least 1 sample.
            if early_out and i > 1:
                # Spread between the best candidate score and the median candidate score.
                current_spread = partial_multi_bald_B[winner_index] - partial_multi_bald_B[
                    torch.isfinite(partial_multi_bald_B)
                ].median()
                if current_spread < MIN_SPREAD:
                    print("Early out")
                    break

            acquisition_bag_scores.append(actual_multi_bald_B)

            # Algorithm 1: line 5
            subset_acquisition_bag.append(winner_index)
            # We need to map the index back to the actual dataset.
            global_acquisition_bag.append(subset_split.get_dataset_indices([winner_index]).item())

            print(f"Acquisition bag: {sorted(global_acquisition_bag)}")

    return AcquisitionBatch(global_acquisition_bag, acquisition_bag_scores, None)
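For reference, the quantity the comments refer to as Equation 8 is the BatchBALD acquisition score from the BatchBALD paper (Kirsch et al., 2019):

\[
a_{\mathrm{BatchBALD}}\bigl(\{x_1,\dots,x_n\}\bigr)
\;=\; H(y_1,\dots,y_n) \;-\; \mathbb{E}_{p(w\mid\mathcal{D})}\bigl[H(y_1,\dots,y_n\mid w)\bigr]
\;=\; H(y_1,\dots,y_n) \;-\; \sum_{i=1}^{n}\mathbb{E}_{p(w\mid\mathcal{D})}\bigl[H(y_i\mid w)\bigr],
\]

where the second equality holds because the y_i are conditionally independent given w. The greedy loop above maximizes this score over the candidate appended as y_n; terms that are constant across candidates are dropped, which is why partial_multi_bald_B suffices for the argmax.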