class MarginLoss(Function):
    r"""Margin based loss with a learnable class-specific boundary beta.

    Parameters
    ----------
    margin : float
        Margin between positive and negative pairs.
    nu : float
        Regularization parameter for beta.
    weight : float or None
        Optional scalar multiplier applied to the final loss.

    Inputs:
        - anchors: sampled anchor embeddings.
        - positives: sampled positive embeddings.
        - negatives: sampled negative embeddings.
        - beta_in: class-specific betas (tensor) or a constant beta.
        - a_indices: indices of anchors. Used to get class-specific beta.

    Outputs:
        - Loss.
    """

    def __init__(self, margin=0.2, nu=0.0, weight=None, batch_axis=0, **kwargs):
        super(MarginLoss, self).__init__()
        self.margin = margin
        self.nu = nu
        self.pdist = PairwiseDistance(2)
        self.weight = weight

    def forward(self, anchors, positives, negatives, beta_in, a_indices=None):
        if a_indices is not None:
            # Jointly train class-specific beta, regularized by nu.
            # NOTE(review): original comment (translated from Chinese):
            # "confirm whether beta_in needs to be a Variable".
            beta = beta_in.index_select(0, a_indices)
            beta_reg_loss = torch.sum(beta) * self.nu
        else:
            # Use a constant beta.
            beta = beta_in
            beta_reg_loss = 0.0

        d_p = self.pdist.forward(anchors, positives)
        d_n = self.pdist.forward(anchors, negatives)

        pos_loss = torch.clamp(self.margin + d_p - beta, min=0.0)
        neg_loss = torch.clamp(self.margin - d_n + beta, min=0.0)

        # Count pairs contributing a non-zero loss; "+" on numpy boolean
        # arrays is elementwise logical OR, so each pair counts once.
        pair_cnt = float(
            np.sum((pos_loss.cpu().data.numpy() > 0.0)
                   + (neg_loss.cpu().data.numpy() > 0.0)))

        # Normalize by the number of active pairs. BUG FIX: guard against
        # division by zero when every pair already satisfies the margin
        # (the original produced nan/inf in that case).
        total = (torch.sum(torch.pow(pos_loss, 2) + torch.pow(neg_loss, 2))
                 + beta_reg_loss)
        loss = total / max(pair_cnt, 1.0)

        if self.weight:
            loss = loss * self.weight
        return loss
class TripletSoftMarginLoss(Function):
    """Triplet loss with a soft-plus term on the positive distance.

    Computes mean(max(0, softplus(d_ap) + d_ap - d_an)) over the batch,
    where d_ap / d_an are L2 distances of anchor-positive / anchor-negative.
    """

    def __init__(self):
        # BUG FIX: the original called super(TripletSoftMarginLoss).__init__()
        # without passing `self`, which creates an unbound super object and
        # never runs the base-class initializer.
        super(TripletSoftMarginLoss, self).__init__()
        self.pdist = PairwiseDistance(2)
        # NOTE: attribute name "activion" (sic) kept for backward
        # compatibility with any external code reading it.
        self.activion = torch.nn.Softplus()

    def forward(self, anchor, positive, negative):
        d_p = self.pdist.forward(anchor, positive)
        d_n = self.pdist.forward(anchor, negative)
        dist_hinge = torch.clamp(self.activion(d_p) + d_p - d_n, min=0.0)
        loss = torch.mean(dist_hinge)
        return loss
class TripletMarginLoss(Function):
    """Standard triplet loss: mean(max(0, margin + d(a, p) - d(a, n)))."""

    def __init__(self, margin):
        super(TripletMarginLoss, self).__init__()
        self.margin = margin
        self.pdist = PairwiseDistance(2)  # Euclidean (L2) distance

    def forward(self, anchor, positive, negative):
        pos_dist = self.pdist.forward(anchor, positive)
        neg_dist = self.pdist.forward(anchor, negative)
        hinge = torch.clamp(self.margin + pos_dist - neg_dist, min=0.0)
        return torch.mean(hinge)
def train_epoch_some(train_loader, model, loss_fn, optimizer, cuda,
                     log_interval, metrics):
    """Run one training epoch with hard-negative mining.

    For each batch of (anchor, positive, negative) triplets, keeps only the
    triplets that violate the margin (d_an - d_ap < args.margin), then
    optimizes the sum of the triplet loss on embeddings and a cross-entropy
    loss on the classifier outputs for the selected samples.

    Parameters
    ----------
    train_loader : iterable yielding (data_a, data_p, data_n, label_p, label_n)
    model : network exposing __call__ (embeddings) and forward_classifier.
    loss_fn : triplet loss object with a .forward(a, p, n) method.
    optimizer : torch optimizer updating `model`'s parameters.
    cuda, log_interval : accepted for interface compatibility.
        NOTE(review): `cuda` is ignored — data is moved to GPU
        unconditionally, matching the original behavior; confirm callers.
    metrics : iterable of metric objects; reset at epoch start.
    """
    for metric in metrics:
        metric.reset()
    model.train()
    pbar = tqdm(enumerate(train_loader))
    l2_dist = PairwiseDistance(2)
    # Hoisted out of the loop: the criterion is stateless, no need to
    # re-construct it every batch (original did).
    criterion = nn.CrossEntropyLoss()

    for batch_idx, (data_a, data_p, data_n, label_p, label_n) in pbar:
        data_a, data_p, data_n = data_a.cuda(), data_p.cuda(), data_n.cuda()
        data_a, data_p, data_n = Variable(data_a), Variable(data_p), \
            Variable(data_n)

        # compute embeddings
        out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)

        # Choose the hard negatives: triplets whose margin is violated.
        # (renamed from `all`, which shadowed the builtin)
        d_p = l2_dist.forward(out_a, out_p)
        d_n = l2_dist.forward(out_a, out_n)
        violated = (d_n - d_p < args.margin).cpu().data.numpy().flatten()
        hard_triplets = np.where(violated == 1)
        if len(hard_triplets[0]) == 0:
            continue

        out_selected_a = Variable(
            torch.from_numpy(out_a.cpu().data.numpy()[hard_triplets]).cuda())
        out_selected_p = Variable(
            torch.from_numpy(out_p.cpu().data.numpy()[hard_triplets]).cuda())
        out_selected_n = Variable(
            torch.from_numpy(out_n.cpu().data.numpy()[hard_triplets]).cuda())

        selected_data_a = Variable(
            torch.from_numpy(data_a.cpu().data.numpy()[hard_triplets]).cuda())
        selected_data_p = Variable(
            torch.from_numpy(data_p.cpu().data.numpy()[hard_triplets]).cuda())
        selected_data_n = Variable(
            torch.from_numpy(data_n.cpu().data.numpy()[hard_triplets]).cuda())

        selected_label_p = torch.from_numpy(
            label_p.cpu().numpy()[hard_triplets])
        selected_label_n = torch.from_numpy(
            label_n.cpu().numpy()[hard_triplets])

        triplet_loss = loss_fn.forward(out_selected_a, out_selected_p,
                                       out_selected_n)

        # BUG FIX: the original ran each classifier forward pass twice,
        # discarding the first results — wasted compute, and in train mode
        # BatchNorm running statistics were updated twice per batch.
        cls_a = model.forward_classifier(selected_data_a)
        cls_p = model.forward_classifier(selected_data_p)
        cls_n = model.forward_classifier(selected_data_n)

        predicted_labels = torch.cat([cls_a, cls_p, cls_n])
        # anchor and positive share a label, hence label_p appears twice
        true_labels = torch.cat([
            Variable(selected_label_p.cuda()),
            Variable(selected_label_p.cuda()),
            Variable(selected_label_n.cuda())
        ])
        cross_entropy_loss = criterion(predicted_labels.cuda(),
                                       true_labels.cuda())

        loss = cross_entropy_loss + triplet_loss

        # compute gradient and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # update the optimizer learning rate
        adjust_learning_rate(optimizer)
def test(test_loader, model, epoch):
    """Evaluate `model` on verification pairs and log the mean accuracy.

    Computes L2 distances between embedding pairs across the whole loader,
    then scores them with `evaluate` and reports via print/logger.
    """
    # switch to evaluate mode
    model.eval()

    collected_labels = []
    collected_dists = []
    progress = tqdm(enumerate(test_loader))

    for batch_idx, (data_a, data_p, label) in progress:
        if args.cuda:
            data_a, data_p = data_a.cuda(), data_p.cuda()
        data_a = Variable(data_a, volatile=True)
        data_p = Variable(data_p, volatile=True)
        label = Variable(label)

        # compute embeddings for both sides of the pair
        out_a, out_p = model(data_a), model(data_p)
        # euclidean distance between the two embeddings
        dists = l2_dist.forward(out_a, out_p)
        collected_dists.append(dists.data.cpu().numpy())
        collected_labels.append(label.data.cpu().numpy())

        if batch_idx % args.log_interval == 0:
            progress.set_description('Test Epoch: {} [{}/{} ({:.0f}%)]'.format(
                epoch, batch_idx * len(data_a), len(test_loader.dataset),
                100. * batch_idx / len(test_loader)))

    # flatten the per-batch arrays into flat numpy vectors
    labels = np.array(
        [item for batch in collected_labels for item in batch])
    distances = np.array(
        [entry[0] for batch in collected_dists for entry in batch])

    tpr, fpr, accuracy, val, val_std, far = evaluate(distances, labels)
    print('\33[91mTest set: Accuracy: {:.8f}\n\33[0m'.format(np.mean(accuracy)))
    logger.log_value('Test Accuracy', np.mean(accuracy))