Example 1
    def forward(self, student, teacher):
        # distance-wise distillation: match the pairwise-distance structure of the
        # student embeddings to that of the teacher
        with torch.no_grad():
            t_d = pdist(teacher, squared=False)
            mean_td = t_d[t_d > 0].mean()
            t_d = t_d / mean_td  # normalize by the mean distance so the embedding scale is irrelevant

        d = pdist(student, squared=False)
        mean_d = d[d > 0].mean()
        d = d / mean_d

        # 'elementwise_mean' was removed from recent PyTorch; 'mean' is the current equivalent
        loss = F.smooth_l1_loss(d, t_d, reduction='mean')
        return loss
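Every example on this page calls a pdist helper that the page never shows. A minimal sketch of what it could look like, assuming it returns the dense N x N matrix of pairwise Euclidean distances (squared when squared=True) with a zeroed diagonal:

import torch

def pdist(e, squared=False, eps=1e-12):
    # squared pairwise distances via the Gram matrix: ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>
    e_square = e.pow(2).sum(dim=1)
    prod = e @ e.t()
    res = (e_square.unsqueeze(1) + e_square.unsqueeze(0) - 2 * prod).clamp(min=eps)

    if not squared:
        res = res.sqrt()

    # clone before the in-place edit so autograd does not see a modified sqrt output
    res = res.clone()
    res[range(len(e)), range(len(e))] = 0  # a point has distance 0 to itself
    return res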
Example 2
    def forward(self, student, teacher):
        with torch.no_grad():
            t_d = pdist(teacher, squared=False)
            mean_td = t_d[t_d > 0].mean()
            t_d = t_d / mean_td

        d = pdist(student, squared=False)
        mean_d = d[d > 0].mean()
        d = d / mean_d

        loss = F.smooth_l1_loss(d, t_d, reduction='mean')
        return loss
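A hedged sketch of how the distillation loss above might be driven. The wrapper class name RkdDistance, the placeholder networks and the batch are assumptions for illustration, not names from this page:

import torch
import torch.nn as nn

# placeholder networks; only the embedding dimensions matter for the illustration
student_net = nn.Linear(32, 16)
teacher_net = nn.Linear(32, 64)
criterion = RkdDistance()  # assumed name for the module owning the forward() above

x = torch.randn(8, 32)              # a batch of 8 inputs
with torch.no_grad():
    teacher_emb = teacher_net(x)    # the teacher stays frozen
loss = criterion(student_net(x), teacher_emb)
loss.backward()

Note that the student and teacher embedding sizes may differ, since only pairwise distances within each embedding space are compared.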
Example 3
    def forward(self, embeddings, labels):
        # distance-weighted sampling: for each positive pair, draw one negative with
        # probability inversely proportional to the pairwise-distance density
        with torch.no_grad():
            embeddings = F.normalize(embeddings, dim=1, p=2)
            pos_mask, neg_mask = pos_neg_mask(labels)
            pos_pair_idx = pos_mask.nonzero()
            anchor_idx = pos_pair_idx[:, 0]
            pos_idx = pos_pair_idx[:, 1]

            d = embeddings.size(1)
            # add the identity so the self-distances on the diagonal become 1
            # rather than 0 after the square root, then clamp small distances
            dist = (
                pdist(embeddings, squared=True)
                + torch.eye(
                    embeddings.size(0), device=embeddings.device, dtype=torch.float32
                )
            ).sqrt()
            dist = dist.clamp(min=self.cut_off)

            # log of the inverse pairwise-distance density on the unit d-sphere
            # (distance-weighted sampling); subtract the row max before exponentiating
            # for numerical stability, then keep only sufficiently close negatives
            log_weight = (2.0 - d) * dist.log() - ((d - 3.0) / 2.0) * (
                1.0 - 0.25 * (dist * dist)
            ).log()
            weight = (log_weight - log_weight.max(dim=1, keepdim=True)[0]).exp()
            weight = weight * (neg_mask * (dist < self.nonzero_loss_cutoff)).float()

            # rows whose weights all vanished fall back to a uniform choice over
            # negatives; then sample one negative per (anchor, positive) pair
            weight = (
                weight + ((weight.sum(dim=1, keepdim=True) == 0) * neg_mask).float()
            )
            weight = weight / (weight.sum(dim=1, keepdim=True))
            weight = weight[anchor_idx]
            neg_idx = torch.multinomial(weight, 1).squeeze(1)

        return anchor_idx, pos_idx, neg_idx
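The sampler above also relies on a pos_neg_mask helper that is not shown. A minimal sketch, assuming same-label pairs (minus the diagonal) count as positives and different-label pairs as negatives:

import torch

def pos_neg_mask(labels):
    # labels: (N,) integer class labels for the batch
    same = labels.unsqueeze(0) == labels.unsqueeze(1)
    eye = torch.eye(labels.size(0), dtype=torch.bool, device=labels.device)
    pos_mask = same & ~eye   # same class, but not the anchor itself
    neg_mask = ~same         # any different-class pair
    return pos_mask, neg_mask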
Example 4
    def forward(self, embeddings):
        dist_mat = pdist(embeddings)
        # off-diagonal entries only (the diagonal of a distance matrix is all zeros)
        pdist_mat = dist_mat[~torch.eye(
            dist_mat.shape[0],
            dtype=torch.bool,
            device=dist_mat.device,
        )]
        dist_mat = dist_mat.view(-1)

        # batch statistics are taken over the full flattened matrix (diagonal
        # zeros included) and detached so they do not receive gradients
        mean = dist_mat.mean().detach()
        std = dist_mat.std().detach()

        if not self.init:
            self.initialize_statistics(mean, std)
        else:
            # exponential moving average of the batch statistics
            self.momented_mean = (
                1 - self.momentum) * mean + self.momentum * self.momented_mean
            self.momented_std = (
                1 - self.momentum) * std + self.momentum * self.momented_std

        # normalize the distances, then penalize each one's deviation from the
        # nearest of the target levels
        normalized_dist = (pdist_mat - self.momented_mean) / self.momented_std
        difference = (normalized_dist[None] -
                      self.levels[:, None]).abs().min(dim=0)[0]
        loss = difference.mean()

        return loss
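Example 4 references self.init, self.momentum, self.levels and an initialize_statistics method, none of which appear on this page; self.levels is presumably a tensor of target normalized-distance values. One plausible shape for the missing method, purely as an assumption:

def initialize_statistics(self, mean, std):
    # the first batch seeds the running statistics; later batches update them
    # with the exponential moving average in forward()
    self.momented_mean = mean
    self.momented_std = std
    self.init = True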
Example 5
    def forward(self, student, teacher):
        with torch.no_grad():
            # score pairs by scaled, powered negative distances and keep the teacher's
            # top permute_len neighbours per anchor (index 0 is the anchor itself)
            score_teacher = -1 * self.alpha * pdist(
                teacher, squared=False).pow(self.beta)
            permute_idx = score_teacher.sort(
                dim=1, descending=True)[1][:, 1:(self.permute_len + 1)]

        score_student = -1 * self.alpha * pdist(
            student, squared=False).pow(self.beta)
        ordered_student = torch.gather(score_student, 1, permute_idx)

        # Plackett-Luce log-likelihood of the teacher's ranking under the student's scores
        log_prob = (ordered_student - torch.stack(
            [torch.logsumexp(ordered_student[:, i:], dim=1)
             for i in range(permute_idx.size(1))],
            dim=1)).sum(dim=1)
        loss = (-1 * log_prob).mean()

        return loss
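Example 5 uses self.alpha, self.beta and self.permute_len without showing where they come from. A hypothetical constructor with placeholder defaults:

def __init__(self, alpha=1.0, beta=1.0, permute_len=4):
    super().__init__()
    self.alpha = alpha              # scales the distance-based scores
    self.beta = beta                # power applied to the pairwise distances
    self.permute_len = permute_len  # how many teacher-ranked neighbours to distill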
Example 6
    def __init__(self,
                 margin=0.2,
                 sampler=None,
                 reduce=True,
                 size_average=True):
        super().__init__()
        self.margin = margin

        # update the sampler's distance function; unlike Example 8 this variant
        # takes no `p` argument, so squared Euclidean distances are used directly
        self.sampler = sampler
        self.sampler.dist_func = lambda e: pdist(e, squared=True)

        self.reduce = reduce
        self.size_average = size_average
Example 7
    def forward(self, x):
        feat = self.base(x)
        feat = feat.view(x.size(0), -1)
        embedding = self.linear(feat)

        if self.training:
            # see "Learning without L2 Norm" in Section 3.1: during training the
            # embedding is scaled by the mean of its off-diagonal pairwise distances
            dist_mat = pdist(embedding)
            mean_d = dist_mat[
                ~torch.eye(dist_mat.shape[0], dtype=torch.bool, device=dist_mat.device)
            ].mean()
            return embedding / mean_d

        return embedding
Example 8
    def __init__(self,
                 p=2,
                 margin=0.2,
                 sampler=None,
                 reduce=True,
                 size_average=True):
        super().__init__()
        self.p = p
        self.margin = margin

        # update distance function accordingly
        self.sampler = sampler
        self.sampler.dist_func = lambda e: pdist(e, squared=(p == 2))

        self.reduce = reduce
        self.size_average = size_average
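In Examples 6 and 8 the sampler is handed squared distances whenever p == 2. A small self-check of that convention, using the pdist sketch from earlier on this page:

import torch

e = torch.randn(4, 8)
d_sq = pdist(e, squared=True)   # what the sampler sees when p == 2
d = pdist(e, squared=False)     # what it would see otherwise
assert torch.allclose(d_sq, d * d, atol=1e-5)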