Example #1
    def forward(self, student, teacher):
        # Teacher distances are fixed targets: no gradient flows through
        # them, and dividing by the mean of the nonzero entries makes the
        # loss invariant to the overall scale of the teacher embedding.
        with torch.no_grad():
            t_d = pdist(teacher, squared=False)
            mean_td = t_d[t_d > 0].mean()
            t_d = t_d / mean_td

        # Student distances are normalized the same way.
        d = pdist(student, squared=False)
        mean_d = d[d > 0].mean()
        d = d / mean_d

        # Huber (smooth L1) loss between the normalized distance matrices.
        loss = F.smooth_l1_loss(d, t_d, reduction='mean')
        return loss
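Every example on this page calls a `pdist` helper that is not shown. From how it is used (a `squared` flag, a symmetric matrix whose zero diagonal Example #1 masks out with `t_d > 0` and Example #2 patches with an identity matrix), it almost certainly computes the full pairwise Euclidean distance matrix. A minimal sketch under that assumption; the `eps` clamp guarding `sqrt` is likewise assumed:

import torch

def pdist(e, squared=False, eps=1e-12):
    # Pairwise distances between rows of e via |a - b|^2 = |a|^2 + |b|^2 - 2<a, b>.
    prod = e @ e.t()
    sq_norm = prod.diag()
    res = (sq_norm.unsqueeze(1) + sq_norm.unsqueeze(0) - 2 * prod).clamp(min=0)
    if not squared:
        res = res.clamp(min=eps).sqrt()
    # Force an exact zero diagonal (the eps clamp would otherwise leave
    # sqrt(eps) there).
    idx = torch.arange(len(e))
    res[idx, idx] = 0
    return res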
Example #2
    def forward(self, embeddings, labels):
        # Distance-weighted negative sampling (cf. Wu et al., "Sampling
        # Matters in Deep Embedding Learning"): negatives are drawn with
        # probability inversely proportional to the density of pairwise
        # distances on the unit sphere. Sampling needs no gradients.
        with torch.no_grad():
            embeddings = F.normalize(embeddings, dim=1, p=2)
            pos_mask, neg_mask = pos_neg_mask(labels)
            pos_pair_idx = pos_mask.nonzero()
            anchor_idx = pos_pair_idx[:, 0]
            pos_idx = pos_pair_idx[:, 1]

            d = embeddings.size(1)
            # Adding the identity matrix before sqrt() keeps the zero
            # diagonal from producing NaN gradients.
            dist = (pdist(embeddings, squared=True) +
                    torch.eye(embeddings.size(0),
                              device=embeddings.device,
                              dtype=torch.float32)).sqrt()
            dist = dist.clamp(min=self.cut_off)

            # Log of the inverse distance density,
            # q(dist) ~ dist^(d-2) * (1 - dist^2/4)^((d-3)/2) for
            # d-dimensional unit vectors; subtracting the row-wise max
            # before exp() is the standard numerical-stability trick.
            log_weight = ((2.0 - d) * dist.log() - ((d - 3.0) / 2.0) *
                          (1.0 - 0.25 * (dist * dist)).log())
            weight = (log_weight -
                      log_weight.max(dim=1, keepdim=True)[0]).exp()
            weight = weight * (neg_mask *
                               (dist < self.nonzero_loss_cutoff)).float()

            # Rows whose weights all vanished fall back to uniform
            # sampling over their negatives; then normalize per row and
            # draw one negative per positive pair.
            weight = weight + (
                (weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float()
            weight = weight / (weight.sum(dim=1, keepdim=True))
            weight = weight[anchor_idx]
            neg_idx = torch.multinomial(weight, 1).squeeze(1)

        return anchor_idx, pos_idx, neg_idx
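Example #2 also relies on a `pos_neg_mask` helper that is not shown. Judging from its usage, it returns two boolean masks over all pairs: positives (same label, self-pairs excluded) and negatives (different label). A sketch under that assumption:

import torch

def pos_neg_mask(labels):
    # pos_mask[i, j]: i and j share a label and i != j.
    # neg_mask[i, j]: i and j carry different labels.
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(len(labels), dtype=torch.bool, device=labels.device)
    return same & ~eye, ~same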
Example #3
    def forward(self, student, teacher):
        # Turn pairwise distances into similarities: score = -alpha * dist^beta.
        score_teacher = -1 * self.alpha * pdist(teacher, squared=False).pow(
            self.beta)
        score_student = -1 * self.alpha * pdist(student, squared=False).pow(
            self.beta)

        # For each anchor, take the indices of the teacher's top
        # permute_len neighbors (position 0 is the anchor itself, hence
        # the 1: slice) and read the student's scores in that order.
        permute_idx = score_teacher.sort(
            dim=1, descending=True)[1][:, 1:(self.permute_len + 1)]
        ordered_student = torch.gather(score_student, 1, permute_idx)

        # Plackett-Luce log-likelihood that the student reproduces the
        # teacher's ranking: at position i the chosen item competes
        # against every item not yet ranked.
        log_prob = (ordered_student - torch.stack(
            [torch.logsumexp(ordered_student[:, i:], dim=1)
             for i in range(permute_idx.size(1))],
            dim=1)).sum(dim=1)
        loss = (-log_prob).mean()

        return loss
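This forward reads self.alpha, self.beta, and self.permute_len, so it presumably sits in an nn.Module resembling the hypothetical wrapper below; the class name and default values are illustrative, not taken from the snippet:

import torch.nn as nn

class ListwiseRankLoss(nn.Module):  # hypothetical name
    def __init__(self, alpha=3.0, beta=3.0, permute_len=4):
        super().__init__()
        self.alpha = alpha    # score scale
        self.beta = beta      # score exponent
        self.permute_len = permute_len  # length of the ranked list to match

    # the forward(self, student, teacher) from Example #3 goes here

With that forward pasted in, ListwiseRankLoss()(student, teacher) returns a scalar loss.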
Example #4
def findNumCorrect(embed, labels, K=(1,)):
    # Recall@k: a query counts as correct if at least one of its k
    # nearest neighbors (itself excluded) shares its label.
    D = pdist(embed, squared=True)
    knn_inds = D.topk(1 + max(K), dim=1, largest=False, sorted=True)[1][:, 1:]

    # Sanity check: a query must never appear among its own neighbors.
    assert (knn_inds == torch.arange(
        0, len(labels), device=knn_inds.device).unsqueeze(1)).sum().item() == 0

    selected_labels = labels[knn_inds.contiguous().view(-1)].view_as(knn_inds)
    correct_labels = labels.unsqueeze(1) == selected_labels

    # Recall@k for every requested k.
    return [(correct_labels[:, :k].sum(dim=1) > 0).float().mean().item()
            for k in K]
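A quick smoke test, assuming the `pdist` sketch from above; the returned list holds recall@k for each requested k:

import torch

embed = torch.randn(100, 64)
labels = torch.randint(0, 10, (100,))
recalls = findNumCorrect(embed, labels, K=(1, 2, 4))  # [recall@1, recall@2, recall@4]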
Example #5
    def __init__(self,
                 p=2,
                 margin=0.2,
                 sampler=None,
                 reduce=True,
                 size_average=True):
        super().__init__()
        self.p = p
        self.margin = margin

        # Point the sampler at the same metric the loss uses: squared
        # Euclidean distances for p=2, plain Euclidean otherwise. Guard
        # against the default sampler=None.
        self.sampler = sampler
        if self.sampler is not None:
            self.sampler.dist_func = lambda e: pdist(e, squared=(p == 2))

        self.reduce = reduce
        self.size_average = size_average
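Overriding the sampler's dist_func keeps triplet mining and the loss on one metric instead of letting them drift apart. A hypothetical construction, with both class names invented for illustration (the snippet shows neither); the cut_off and nonzero_loss_cutoff arguments mirror the attributes used in Example #2:

sampler = DistanceWeightedSampler(cut_off=0.5,
                                  nonzero_loss_cutoff=1.4)  # e.g. a module like Example #2
criterion = TripletLoss(p=2, margin=0.2, sampler=sampler)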