def check_triplets_are_hardest(
    ids_anchor: List[int],
    ids_pos: List[int],
    ids_neg: List[int],
    labels: List[int],
    distmat: Tensor,
) -> None:
    """
    Assert that each selected triplet is the hardest one for its anchor.

    The positive must be as far from the anchor as the farthest other sample
    sharing the anchor's label. The negative must be as close as the closest
    sample with a different label. Distances are compared with
    ``torch.isclose``.

    Args:
        ids_anchor: anchor indexes of selected triplets
        ids_pos: positive indexes of selected triplets
        ids_neg: negative indexes of selected triplets
        labels: labels of the samples in the batch
        distmat: distances between features

    Raises:
        AssertionError: if some triplet is not the hardest one
    """
    every_id = set(range(len(labels)))

    for anchor, positive, negative in zip(ids_anchor, ids_pos, ids_neg):
        same_class = set(find_value_ids(it=labels, value=labels[anchor]))
        candidates_pos = np.array(list(same_class - {anchor}), int)
        candidates_neg = np.array(list(every_id - same_class), int)

        # Positive check first, then negative, so a failure surfaces at the same point.
        farthest_pos_dist = distmat[anchor, candidates_pos].max()
        assert torch.isclose(farthest_pos_dist, distmat[anchor, positive])

        closest_neg_dist = distmat[anchor, candidates_neg].min()
        assert torch.isclose(closest_neg_dist, distmat[anchor, negative])
def _sample_from_distmat(distmat: Tensor, labels: List[int]) -> TTripletsIds: """ This method samples the hardest triplets based on the given distances matrix. It chooses each sample in the batch as an anchor and then finds the harderst positive and negative pair. Args: distmat: matrix of distances between the features labels: labels of the samples in the batch Returns: the batch of triplets in the order below: (anchor, positive, negative) """ ids_all = set(range(len(labels))) ids_anchor, ids_pos, ids_neg = [], [], [] for i_anch, label in enumerate(labels): ids_label = set(find_value_ids(it=labels, value=label)) ids_pos_cur = np.array(list(ids_label - {i_anch}), int) ids_neg_cur = np.array(list(ids_all - ids_label), int) i_pos = ids_pos_cur[distmat[i_anch, ids_pos_cur].argmax()] i_neg = ids_neg_cur[distmat[i_anch, ids_neg_cur].argmin()] ids_anchor.append(i_anch) ids_pos.append(i_pos) ids_neg.append(i_neg) return ids_anchor, ids_pos, ids_neg
def _sample(self, *_: Tensor, labels: List[int]) -> TTripletsIds:
    """
    Build every triplet available in the batch and randomly keep at most
    ``self._max_out_triplets`` of them.

    For each label, every unordered pair of same-label samples is produced once
    by ``itertools.combinations``. The element coming first in iteration order
    acts as the anchor. Each pair is combined with every sample of a different
    label.

    Args:
        labels: labels of the samples in the batch
        *_: note, that we ignore features argument

    Returns:
        indices of triplets as three lists (anchor, positive, negative).
        All three are empty if no triplet can be formed: every label is unique,
        all samples share one label, or ``self._max_out_triplets`` is 0.
    """
    num_labels = len(labels)

    triplets = []
    for label in set(labels):
        ids_pos_cur = set(find_value_ids(labels, label))
        ids_neg_cur = set(range(num_labels)) - ids_pos_cur

        pos_pairs = list(combinations(ids_pos_cur, r=2))
        triplets.extend((a, p, n) for (a, p), n in product(pos_pairs, ids_neg_cur))

    triplets = sample(triplets, min(len(triplets), self._max_out_triplets))

    if not triplets:
        # zip(*[]) yields nothing, so unpacking it into three names would raise ValueError
        return [], [], []

    ids_anchor, ids_pos, ids_neg = zip(*triplets)
    return list(ids_anchor), list(ids_pos), list(ids_neg)
def _get_labels_mask(labels: List[int]) -> Tensor: """ Generate matrix of bool of shape (n_unique_labels, batch_size), where n_unique_labels is a number of unique labels in the batch; matrix[i, j] is True if j-th element of the batch relates to i-th class and False otherwise. Args: labels: labels of the batch, shape (batch_size) Returns: matrix of indices of classes in batch """ unique_labels = sorted(np.unique(labels)) labels_number = len(unique_labels) labels_mask = torch.zeros(size=(labels_number, len(labels))) for label_idx, label in enumerate(unique_labels): label_indices = find_value_ids(labels, label) labels_mask[label_idx][label_indices] = 1 return labels_mask.type(torch.bool)
def __iter__(self) -> Iterator[int]:
    """
    Produce the dataset indices for one epoch.

    ``self._num_epoch_classes`` classes are drawn without replacement from
    ``self._classes``. For each drawn class, ``self._k`` of its indices are
    emitted consecutively. If the class has fewer than ``self._k`` samples,
    all of them are used once (in random order), and the rest is filled by
    drawing from the class with replacement.

    Returns:
        iterator over indices for sampling dataset elements during an epoch
    """
    epoch_indices = []
    chosen_classes = random.sample(self._classes, self._num_epoch_classes)

    for class_label in chosen_classes:
        class_indices = find_value_ids(self._labels, class_label)
        # per the original note, __init__ checks that this count is > 1
        available = len(class_indices)

        if available >= self._k:
            picked = random.sample(class_indices, k=self._k)
        else:
            picked = random.sample(class_indices, k=available)
            picked += random.choices(class_indices, k=self._k - available)

        epoch_indices += picked

    return iter(epoch_indices)