Example #1
0
    def sample(self, features: Tensor, labels: TLabels) -> TTriplets:
        """
        Mine triplets from the batch and return their feature vectors.

        The labels are converted with ``convert_labels2list`` and validated via
        ``self._check_input_labels``; index selection is delegated to
        ``self._sample``, and the returned index collections are used to index
        the first dimension of ``features``.

        Args:
            features: has the shape of [batch_size, feature_size]
            labels: labels of the samples in the batch

        Returns:
            the batch of the triplets in the order below:
            (anchor, positive, negative)
        """
        label_list = convert_labels2list(labels)
        self._check_input_labels(labels=label_list)

        # ``_sample`` must yield exactly three index collections; unpacking
        # enforces that before any gathering happens.
        anchor_ids, positive_ids, negative_ids = self._sample(
            features, labels=label_list
        )
        return tuple(
            features[ids] for ids in (anchor_ids, positive_ids, negative_ids)
        )
Example #2
0
    def sample(self, features: Tensor, labels: TLabels) -> TTriplets:
        """
        This method samples the hardest triplets in the batch.

        For every class in the batch the anchor is the class mean vector, the
        positive is the sample of that class farthest from its mean vector, and
        the negative is the mean vector of the closest other class.

        Args:
            features: tensor of shape (batch_size, embed_dim) that contains
                k samples for each of p classes
            labels: labels of the batch, list or tensor of size (batch_size,)

        Returns:
            p triplets of (mean_vector, positive, negative_mean_vector)
        """
        # Convert labels to list
        labels = convert_labels2list(labels)
        self._check_input_labels(labels)

        # Get matrix of indices of labels in batch; one row per class
        # (presumably a boolean mask of shape (p, batch_size) — confirm
        # against _get_labels_mask).
        labels_mask = self._get_labels_mask(labels)
        p = labels_mask.shape[0]

        embed_dim = features.shape[-1]
        # Reshape embeddings to groups of (p, k, embed_dim) ones,
        # each i-th group contains embeddings of i-th class.
        # NOTE: the view with -1 is only meaningful when every class has the
        # same number k of samples in the batch.
        features = features.repeat((p, 1, 1))
        features = features[labels_mask].view((p, -1, embed_dim))

        # Count mean vectors for each class in batch
        mean_vectors = features.mean(1)

        d_intra = self._count_intra_class_distances(features, mean_vectors)
        # Count the distances to the sample farthest from mean vector
        # for each class.
        pos_indices = d_intra.max(1).indices
        # Count matrix of distances from mean vectors to each other
        d_inter = self._count_inter_class_distances(mean_vectors)
        # For each class mean vector get the closest mean vector; the diagonal
        # is set to +inf so a class is never picked as its own negative.
        d_inter = self._fill_diagonal(d_inter, float("inf"))
        neg_indices = d_inter.min(1).indices

        # Pick features[i, pos_indices[i]] for every class i in one
        # advanced-indexing gather instead of a Python loop + torch.stack.
        class_indices = torch.arange(p, device=pos_indices.device)
        positives = features[class_indices, pos_indices]

        return mean_vectors, positives, mean_vectors[neg_indices]
Example #3
0
    def forward(self, features: Tensor, labels: Union[Tensor,
                                                      List[int]]) -> Tensor:
        """
        Mine triplets from the batch and compute the loss on them.

        The labels are converted with ``convert_labels2list``, triplets are
        obtained from ``self._sampler_inbatch.sample`` and the resulting
        (anchor, positive, negative) features are fed to
        ``self._triplet_margin_loss``.

        Args:
            features: features with the shape of [batch_size, features_dim]
            labels: labels of samples having batch_size elements

        Returns: loss value
        """
        anchor, positive, negative = self._sampler_inbatch.sample(
            features=features, labels=convert_labels2list(labels)
        )
        return self._triplet_margin_loss(
            anchor=anchor, positive=positive, negative=negative
        )