def _filter_and_resample_numba(negative_samples, pairs, positives_index,
                               batch_size, voc_size):
    """Replace, in place, every sampled index that is a known positive.

    For each row ``i`` the positives are looked up in ``positives_index``
    under the key ``(pairs[i][0], pairs[i][1])``. Entries of
    ``negative_samples[i]`` that occur among these positives are overwritten
    with indices drawn via ``np.random.randint(0, voc_size, ...)``; drawing is
    repeated until every such entry has received a value that is not a
    positive.

    :param negative_samples: array indexed as ``[row, column]``; modified in
        place.
    :param pairs: per-row pair of ids used as the lookup key.
    :param positives_index: mapping from an ``(id, id)`` tuple to the array of
        positive indices for that pair (presumably a numba typed dict --
        confirm against the caller).
    :param batch_size: number of rows to process.
    :param voc_size: exclusive upper bound of the freshly drawn indices.

    Nothing is returned. NOTE: the inner loop only ends once enough true
    negatives have been drawn, so it does not terminate for a row whose
    positives cover every index in ``[0, voc_size)``.
    """
    for i in range(batch_size):
        positives = positives_index[(pairs[i][0], pairs[i][1])]
        # inlining the where_in function here results in an internal numba
        # error which asks to file a bug report
        # columns of row i that currently hold a positive
        resample_idx = where_in(negative_samples[i], positives)
        # number of new samples needed
        num_new = len(resample_idx)
        # number already found of the new samples needed
        num_found = 0
        num_remaining = num_new - num_found
        while num_remaining:
            new_samples = np.random.randint(0, voc_size, num_remaining)
            # positions within new_samples that hold true negatives
            idx = where_in(new_samples, positives, not_in=True)
            # write the true negatives found
            if len(idx):
                ctr = 0
                # numba does not support advanced indexing but the loop
                # is optimized so it's faster than numpy anyway
                for j in resample_idx[num_found:num_found + len(idx)]:
                    # idx holds positions, not values: go through it so that
                    # only the draws that passed the filter are written
                    negative_samples[i, j] = new_samples[idx[ctr]]
                    ctr += 1
                num_found += len(idx)
                num_remaining = num_new - num_found
def _filter_and_resample(self, negative_samples: torch.Tensor, slot: int,
                         positive_triples: torch.Tensor):
    """Swap sampled indices that are actual positives for fresh samples.

    For every row of ``positive_triples`` the known positives of the two
    slots other than ``slot`` are fetched from the dataset index of
    ``self.filtering_split``. Entries of the matching row of
    ``negative_samples`` that appear among them are redrawn with
    ``self._sample`` until each one holds a value that is not a positive.

    :param negative_samples: sampled indices, one row per positive triple;
        modified in place.
    :param slot: selects the pair columns and index name (0, 1 or 2).
    :param positive_triples: the triples the samples were drawn for.
    :return: the same ``negative_samples`` tensor after filtering.
    """
    pair_str = ["po", "so", "sp"][slot]
    # name of the index holding the positives for the respective pair
    index_key = f"{self.filtering_split}_{pair_str}_to_{SLOT_STR[slot]}"
    positives_by_pair = self.dataset.index(index_key)

    pair_columns = [[P, O], [S, O], [S, P]][slot]
    pairs = positive_triples[:, pair_columns]

    num_triples = positive_triples.size(0)
    for row in range(num_triples):
        lookup_key = (pairs[row][0].item(), pairs[row][1].item())
        known = positives_by_pair.get(lookup_key).numpy()

        # columns of this row that must be drawn again
        to_replace = where_in(negative_samples[row].numpy(), known)

        filled = 0
        still_needed = len(to_replace)
        while still_needed:
            drawn = self._sample(positive_triples[row, None], slot,
                                 still_needed)
            candidates = drawn.view(-1)
            # positions of the candidates that are true negatives
            keep = where_in(candidates.numpy(), known, not_in=True)
            num_kept = len(keep)
            if not num_kept:
                # nothing usable in this draw; sample again
                continue
            targets = to_replace[filled:filled + num_kept]
            negative_samples[row, targets] = candidates[keep]
            filled += num_kept
            still_needed -= num_kept

    return negative_samples