def train_forward(self, x):
    features, adverbs, actions = x[0], x[1], x[2]
    neg_adverbs, neg_actions = x[3], x[4]

    action_embedding = self.action_embedder(actions)
    neg_action_embedding = self.action_embedder(neg_actions)

    if self.transformer:
        video_embedding, attention_weights = self.video_embedder(
            features, action_embedding)
    else:
        video_embedding = self.video_embedder(features)
        attention_weights = None

    # Positive: the ground-truth adverb modifier applied to the ground-truth action.
    pos_modifiers = torch.stack(
        [self.action_modifiers[adv.item()] for adv in adverbs])
    positive = self.apply_modifiers(pos_modifiers, action_embedding)

    # Negatives: the correct adverb with a wrong action, and a wrong adverb
    # with the correct action.
    negative_act = self.apply_modifiers(pos_modifiers, neg_action_embedding)
    neg_modifiers = torch.stack(
        [self.action_modifiers[adv.item()] for adv in neg_adverbs])
    negative_adv = self.apply_modifiers(neg_modifiers, action_embedding)

    loss_triplet_act = F.triplet_margin_loss(video_embedding, positive,
                                             negative_act, margin=self.margin)
    loss_triplet_adv = F.triplet_margin_loss(video_embedding, positive,
                                             negative_adv, margin=self.margin)
    loss = [loss_triplet_act, loss_triplet_adv]
    return loss, None, attention_weights, video_embedding
def eval_pair_embed_losses(self, args: CommandlineArgs, img_feat, img_hidden_emb,
                           attr_labels, obj_labels, neg_attr_labels,
                           neg_obj_labels, nll_loss_funcs):
    device = args.device
    with ns_profiling_label('labels_to_embeddings'):
        h_A_pos, h_O_pos, g_hidden_pos, g_img_pos = self.labels_to_embeddings(
            attr_labels, obj_labels)
        _, _, g_hidden_neg, g_img_neg = self.labels_to_embeddings(
            neg_attr_labels, neg_obj_labels)

    tloss_g_imgfeat = torch.tensor(0.).to(device)
    if args.train.lambda_feat > 0:
        with ns_profiling_label('tloss_g_imgfeat'):
            tloss_g_imgfeat = triplet_margin_loss(
                img_feat, g_img_pos, g_img_neg,
                margin=args.train.triplet_loss_margin)

    tloss_g_hidden = torch.tensor(0.).to(device)
    if args.train.lambda_ao_emb > 0:
        with ns_profiling_label('tloss_g_hidden'):
            tloss_g_hidden = triplet_margin_loss(
                img_hidden_emb, g_hidden_pos, g_hidden_neg,
                margin=args.train.triplet_loss_margin)

    # Loss_invert terms
    loss_inv_core = torch.tensor(0.).to(device)
    if args.train.lambda_aux_disjoint > 0:  # check hp name
        with ns_profiling_label('loss_inv_core'):
            loss_inv_core = nll_sum_loss(
                self.attr_inv_core_logits(h_A_pos),
                self.obj_inv_core_logits(h_O_pos),
                attr_labels, obj_labels, nll_loss_funcs)

    loss_inv_g_imgfeat = torch.tensor(0.).to(device)
    if args.train.lambda_aux_img > 0:  # check hp name
        with ns_profiling_label('loss_inv_g_imgfeat'):
            loss_inv_g_imgfeat = nll_sum_loss(
                self.attr_inv_g_imgfeat_logits(g_img_pos),
                self.obj_inv_g_imgfeat_logits(g_img_pos),
                attr_labels, obj_labels, nll_loss_funcs)

    loss_inv_g_hidden = torch.tensor(0.).to(device)
    if args.train.lambda_aux > 0:  # check hp name
        with ns_profiling_label('loss_inv_g_hidden'):
            loss_inv_g_hidden = nll_sum_loss(
                self.attr_inv_g_hidden_logits(g_hidden_pos),
                self.obj_inv_g_hidden_logits(g_hidden_pos),
                attr_labels, obj_labels, nll_loss_funcs)

    return tloss_g_hidden, tloss_g_imgfeat, loss_inv_core, loss_inv_g_hidden, loss_inv_g_imgfeat
def calculate_triplet(self, cross_feature, pos_feature, neg_feature, count, cross_length):
    # Both branches of the `self.loss_ratio` check computed the identical
    # loss, so the flag has no effect here; the leftover `ipdb.set_trace()`
    # debugger call does not belong in the loss path.
    logits = F.triplet_margin_loss(cross_feature, pos_feature, neg_feature)
    return logits
def forward(self, x):
    x = x.view(x.size(0), -1)
    x = self.fc1(x)
    x = self.relu1(x)
    x = self.drop1(x)
    x = self.fc2(x)
    x = self.relu2(x)
    x = self.drop2(x)
    x = self.fc3(x)
    # L2-normalise each row of x.
    l2_norm = x.div(torch.norm(x, p=2, dim=1).repeat(x.size(1), 1).t())
    warped_l2_norm = self.masklayer(l2_norm)
    assert x.size(0) == 111
    # 55 positives paired with 56 negatives gives 55 * 56 = 3080 triplets,
    # all anchored at the origin.
    anchors = torch.zeros(128, device=x.device).expand(3080, 128)
    positives = tile(warped_l2_norm[0:55], 0, 56)
    negatives = warped_l2_norm[55:].expand(55, 56, 128).reshape(56 * 55, 128)
    loss = F.triplet_margin_loss(anchors, positives, negatives, p=2, margin=0.2)
    return loss, warped_l2_norm
def forward(self, inputs, targets, perturbed_feature):
    """
    Args:
    - inputs: feature matrix with shape (batch_size, feat_dim)
    - targets: ground truth labels with shape (batch_size)
    """
    n = inputs.size(0)

    # Compute pairwise distance, replace by the official when merged
    dist = torch.pow(inputs, 2).sum(dim=1, keepdim=True).expand(n, n)
    dist = dist + dist.t()
    dist.addmm_(inputs, inputs.t(), beta=1, alpha=-2)
    dist = dist.clamp(min=1e-12).sqrt()  # for numerical stability

    # For each anchor, find the hardest positive and negative
    mat_similarity = targets.expand(n, n).eq(targets.expand(n, n).t()).float()
    sorted_mat_distance, positive_indices = torch.sort(
        dist + (-100000.0) * (1 - mat_similarity), dim=1, descending=True)
    hard_p = sorted_mat_distance[:, 0]
    hard_p_indice = positive_indices[:, 0]
    sorted_mat_distance, negative_indices = torch.sort(
        dist + 100000.0 * mat_similarity, dim=1, descending=False)
    hard_n = sorted_mat_distance[:, 0]
    hard_n_indice = negative_indices[:, 0]
    hard_p_feature = inputs[hard_p_indice, :]
    hard_n_feature = inputs[hard_n_indice, :]

    # Note the deliberate swap: the hardest negative is passed as the
    # "positive" argument, pushing the perturbed feature toward negatives
    # and away from positives.
    loss = 10 * F.triplet_margin_loss(perturbed_feature, hard_n_feature,
                                      hard_p_feature, 0.5)
    return loss
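# A minimal smoke test for the hard-mining forward() above; the class name
# `HardMiningTripletLoss` and the shapes are illustrative assumptions, not
# names from the original source.
import torch

criterion = HardMiningTripletLoss()
inputs = torch.randn(8, 128)                     # (batch_size, feat_dim)
targets = torch.randint(0, 4, (8,))              # identity labels
perturbed = inputs + 0.01 * torch.randn(8, 128)  # e.g. adversarially perturbed features
loss = criterion(inputs, targets, perturbed)     # scalar, scaled by 10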
def train(data_loader, model, optimizer, args, writer):
    for batch in data_loader:
        (images, labels, idx, pimages, plabels, pidx, nimages, nlabels, nidx) = map(
            lambda a: a.to(args.device), batch
        )
        optimizer.zero_grad()

        p_labels, a, rsample = model(images, idx)
        _, b, _ = model(pimages, pidx)
        _, c, _ = model(nimages, nidx)

        loss_margin = F.triplet_margin_loss(a, b, c, margin=0.3, swap=True)

        # loss_class = F.cross_entropy(p_labels, labels)
        bs = labels.shape[0]
        k = p_labels.shape[1]
        loss_class = (
            kornia.losses.focal_loss(
                p_labels.view(bs, k, 1, 1).cpu(), labels.view(bs, 1, 1).cpu(), 0.5
            )
            .to(args.device)
            .mean()
        )
        loss = loss_margin * 10 + loss_class
        loss.backward()

        # Logs
        writer.add_scalar("loss/train", loss.item(), args.steps)

        optimizer.step()
        args.steps += 1
def forward(self):
    a = torch.randn(3, 2)
    b = torch.rand(3, 2)
    c = torch.rand(3)
    log_probs = torch.randn(50, 16, 20).log_softmax(2).detach()
    targets = torch.randint(1, 20, (16, 30), dtype=torch.long)
    input_lengths = torch.full((16,), 50, dtype=torch.long)
    target_lengths = torch.randint(10, 30, (16,), dtype=torch.long)
    # len() takes a single argument, so the losses are collected in a tuple.
    return len((
        F.binary_cross_entropy(torch.sigmoid(a), b),
        F.binary_cross_entropy_with_logits(torch.sigmoid(a), b),
        F.poisson_nll_loss(a, b),
        F.cosine_embedding_loss(a, b, c),
        F.cross_entropy(a, b),
        F.ctc_loss(log_probs, targets, input_lengths, target_lengths),
        # F.gaussian_nll_loss(a, b, torch.ones(5, 1)), # ENTER is not supported in mobile module
        F.hinge_embedding_loss(a, b),
        F.kl_div(a, b),
        F.l1_loss(a, b),
        F.mse_loss(a, b),
        F.margin_ranking_loss(c, c, c),
        F.multilabel_margin_loss(self.x, self.y),
        F.multilabel_soft_margin_loss(self.x, self.y),
        F.multi_margin_loss(self.x, torch.tensor([3])),
        F.nll_loss(a, torch.tensor([1, 0, 1])),
        F.huber_loss(a, b),
        F.smooth_l1_loss(a, b),
        F.soft_margin_loss(a, b),
        F.triplet_margin_loss(a, b, -b),
        # F.triplet_margin_with_distance_loss(a, b, -b), # can't take variable number of arguments
    ))
def train(train_loader, model, optimizer, epoch, logger):
    # switch to train mode
    model.train()
    pbar = tqdm(enumerate(train_loader))
    for batch_idx, (data_a, data_p, data_n) in pbar:
        if args.cuda:
            data_a, data_p, data_n = data_a.cuda(), data_p.cuda(), data_n.cuda()

        out_a, out_p, out_n = model(data_a), model(data_p), model(data_n)

        # HardNet loss
        loss = F.triplet_margin_loss(out_p, out_a, out_n,
                                     margin=args.margin,
                                     swap=args.anchorswap)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        adjust_learning_rate(optimizer, args)
        if logger is not None:
            logger.log_value('loss', loss.item()).step()

        if batch_idx % args.log_interval == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch, batch_idx * len(data_a), len(train_loader.dataset),
                    100. * batch_idx / len(train_loader), loss.item()))

    torch.save({'epoch': epoch + 1, 'state_dict': model.state_dict()},
               '{}/checkpoint_{}.pth'.format(LOG_DIR, epoch))
def adaptive_triplet_loss(anchor, positives, negatives, **kwargs):
    margin = kwargs.pop('margin', 0.25)
    p = kwargs.pop('p', 2)
    eps = kwargs.pop('eps', 1e-6)
    swap = kwargs.pop('swap', True)
    adaptive_factor = kwargs.pop('adaptive_factor', False)
    if kwargs:
        raise TypeError('Unexpected **kwargs: %r' % kwargs)

    # Sum the loss over every (positive, negative) pair for this anchor.
    tt_loss = None
    cpt = 0
    for positive in positives:
        for negative in negatives:
            c_loss = func.triplet_margin_loss(anchor, positive, negative,
                                              margin=margin, eps=eps, p=p,
                                              swap=swap)
            if tt_loss is None:
                tt_loss = c_loss
            else:
                tt_loss += c_loss
            if adaptive_factor:
                # Count only triplets that still violate the margin.
                cpt += 1 if c_loss.item() > 0 else 0
            else:
                cpt += 1
    tt_loss /= cpt if cpt else 1
    return tt_loss
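# Hypothetical usage of adaptive_triplet_loss above, with two positives and
# three negatives per anchor; all shapes are illustrative assumptions.
import torch

anchor = torch.randn(16, 64)
positives = [torch.randn(16, 64) for _ in range(2)]
negatives = [torch.randn(16, 64) for _ in range(3)]
# With adaptive_factor=True the sum is divided by the number of triplets that
# still violate the margin, rather than by all 2 * 3 = 6 of them.
loss = adaptive_triplet_loss(anchor, positives, negatives,
                             margin=0.25, adaptive_factor=True)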
def trainer_Triplet(model, epoch=20000):
    lr = 0.0002
    optimizer = optim.Adam(model.parameters(), lr=lr)
    avg_loss = 0
    criterion = BatchHardTripletLoss()  # unused; plain triplet loss is applied below
    for index in range(epoch):
        # Fires at index 0, 10000, ...; training therefore starts at lr / 10.
        if index % 10000 == 0:
            lr /= 10
            optimizer = optim.Adam(model.parameters(), lr=lr)

        anchor, pos, neg = dataloader.get_triplet_batch()
        anchor = anchor.to(device)
        pos = pos.to(device)
        neg = neg.to(device)

        anchor_features = model(anchor)
        pos_features = model(pos)
        neg_features = model(neg)

        # loss = criterion(embedding, targets)
        loss = F.triplet_margin_loss(anchor_features, pos_features, neg_features)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        avg_loss += loss.item()
        if index % 2000 == 0 and index != 0:
            path = os.path.join(model_path,
                                "{}_BH_VeRI_{}.pt".format(model_name, index))
            torch.save(model.state_dict(), path)
        if index % 50 == 0 and index != 0:
            print('batch {} avgloss {}'.format(index, avg_loss / 50))
            avg_loss = 0
def forward(self,
            anchor_sentences: List[str],
            positive_sentences: List[str],
            negative_sentences: List[str],
            loss_kwargs: Dict[str, Any] = None) -> torch.FloatTensor:
    """
    Returns triplet margin loss so that the model can be fine-tuned for ranking.

    # Parameters

    anchor_sentences : List[str]
        List of "anchor" sentences.
    positive_sentences : List[str]
        List of sentences that should be close to the corresponding "anchor"
        sentences.
    negative_sentences : List[str]
        List of sentences that should be distant to the corresponding "anchor"
        sentences.
    loss_kwargs : Dict[str, Any]
        Optional dictionary of arguments to be passed to the
        `triplet_margin_loss` function. See the PyTorch docs for more details:
        https://pytorch.org/docs/master/nn.functional.html#triplet-margin-loss

    # Returns

    loss : torch.FloatTensor
        Scalar loss value.
    """
    assert len(anchor_sentences) == len(positive_sentences) == len(negative_sentences)
    loss_kwargs = loss_kwargs or {}

    anchor_embeddings = self._encode(anchor_sentences)
    positive_embeddings = self._encode(positive_sentences)
    negative_embeddings = self._encode(negative_sentences)

    loss = F.triplet_margin_loss(anchor_embeddings, positive_embeddings,
                                 negative_embeddings, **loss_kwargs)
    return loss
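# A hypothetical fine-tuning step for the ranking forward() above; the `model`
# variable and the example sentences are illustrative assumptions, not from
# the original source.
loss = model(
    anchor_sentences=["a query about triplet losses"],
    positive_sentences=["a passage that answers the query"],
    negative_sentences=["an unrelated passage"],
    loss_kwargs={"margin": 1.0, "swap": True},
)
loss.backward()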
def train_forward(self, img, obj_label, pos_op_label, neg_obj, neg_op_label):
    anchor = self.image_embedder(img)

    # Positive: the ground-truth attribute operator applied to the
    # ground-truth object embedding.
    obj_emb = self.obj_embedder(torch.tensor(obj_label))
    pos_op = self.attr_ops[pos_op_label]
    positive = self.apply_op(obj_emb, pos_op)

    # Negative: a sampled attribute operator applied to a sampled object.
    neg_obj = torch.tensor(self.objects.index(neg_obj))
    neg_obj_emb = self.obj_embedder(neg_obj)
    neg_op_label = torch.tensor(self.operators.index(neg_op_label))
    neg_op = self.attr_ops[neg_op_label]
    negative = self.apply_op(neg_obj_emb, neg_op)

    loss = F.triplet_margin_loss(anchor, positive, negative, margin=1.5)
    return loss
def test_triplet_margin_loss(self):
    inp1 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True)
    inp2 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True)
    inp3 = torch.randn(32, 128, device='cuda', dtype=self.dtype, requires_grad=True)
    # size_average and reduce are deprecated; reduction='mean' covers both.
    output = F.triplet_margin_loss(inp1, inp2, inp3, margin=1.0, p=2,
                                   eps=1e-06, swap=False, reduction='mean')
def get_loss1_loss2(self, img_loader, data_loader_wo_img):
    self.model.eval()
    loss1 = []
    loss2 = []
    with torch.no_grad():
        if self.dataset == 'age':
            score_dict = self.get_all_rank_scores(img_loader)
            for img_id1, img_id2, label, _, _, _ in data_loader_wo_img:
                img_id1, img_id2 = img_id1.to(self.device), img_id2.to(self.device)
                label = label.to(self.device)
                for i1, i2, l in zip(img_id1, img_id2, label):
                    score_diff = score_dict[int(i1)] - score_dict[int(i2)]
                    loss1.append((score_diff - l) ** 2 * 0.5)
                    loss2.append((score_diff + l) ** 2 * 0.5)
        elif self.dataset == 'lfw':
            score_dict = self.get_all_rank_scores_lfw(img_loader)
            for img_id1, img_id2, label, attr_id, _, _ in data_loader_wo_img:
                img_id1, img_id2 = img_id1.to(self.device), img_id2.to(self.device)
                label, attr_id = label.to(self.device), attr_id.to(self.device)
                for i1, i2, l, a in zip(img_id1, img_id2, label, attr_id):
                    score_diff = score_dict[int(i1)][int(a)].item() - score_dict[int(i2)][int(a)].item()
                    loss1.append((score_diff - l) ** 2 * 0.5)
                    loss2.append((score_diff + l) ** 2 * 0.5)
        elif self.dataset == 'food':
            score_dict = self.get_all_rank_scores_lfw(img_loader)
            for anc, pos, neg, _ in data_loader_wo_img:
                anc, pos, neg = anc.to(self.device), pos.to(self.device), neg.to(self.device)
                for a, p, n in zip(anc, pos, neg):
                    output1 = score_dict[int(a)].view(1, -1)
                    output2 = score_dict[int(p)].view(1, -1)
                    output3 = score_dict[int(n)].view(1, -1)
                    l1 = triplet_margin_loss(output1, output2, output3, margin=1.0, p=2)
                    l2 = triplet_margin_loss(output1, output3, output2, margin=1.0, p=2)
                    loss1.append(l1.item())
                    loss2.append(l2.item())
    return np.array(loss1, dtype=float), np.array(loss2, dtype=float)
def forward(self, embeddings, labels):
    embeddings_sqr = embeddings.pow(2).sum(dim=1)
    # Squared pairwise distances: d(i, j) = |e_i|^2 + |e_j|^2 - 2 <e_i, e_j>.
    distance_matrix = (embeddings_sqr.unsqueeze(1) + embeddings_sqr.unsqueeze(0)
                       - 2 * embeddings.mm(embeddings.t())).cpu()
    labels = labels.cpu().numpy()
    triplets = []
    for label in set(labels):
        label_mask = (labels == label)
        label_indices = np.where(label_mask)[0]
        if len(label_indices) < 2:
            continue
        negative_indices = np.where(np.logical_not(label_mask))[0]
        anchor_positive_pairs = np.array(list(combinations(label_indices, 2)))
        ap_distances = distance_matrix[anchor_positive_pairs[:, 0],
                                       anchor_positive_pairs[:, 1]]
        for (anchor_idx, positive_idx), ap_distance in zip(anchor_positive_pairs,
                                                           ap_distances):
            loss_values = ap_distance - distance_matrix[
                anchor_idx, negative_indices] + self.margin
            if len(loss_values) > 0:
                loss_values = loss_values.detach().cpu().numpy()
                if self.hard:
                    # Hardest negative: the one with the largest violation.
                    hard_negative_idx = np.argmax(loss_values)
                    if loss_values[hard_negative_idx] > 0:
                        triplets.append([anchor_idx, positive_idx,
                                         negative_indices[hard_negative_idx]])
                else:
                    # Semi-hard negative: violates the margin, but is farther
                    # from the anchor than the positive.
                    semihard_negative_indices = np.where(
                        np.logical_and(loss_values < self.margin,
                                       loss_values > 0))[0]
                    if len(semihard_negative_indices) > 0:
                        semihard_negative_idx = np.random.choice(
                            semihard_negative_indices)
                        triplets.append([anchor_idx, positive_idx,
                                         negative_indices[semihard_negative_idx]])
    if len(triplets) > 0:
        triplets = np.array(triplets)
        return F.triplet_margin_loss(embeddings[triplets[:, 0]],
                                     embeddings[triplets[:, 1]],
                                     embeddings[triplets[:, 2]],
                                     margin=self.margin)
    else:
        return torch.tensor(0, dtype=torch.float32, device=embeddings.device)
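# A hypothetical smoke test for the online-mining forward() above; the class
# name `OnlineTripletMiningLoss(margin, hard)` is an assumption for whatever
# module defines it.
import torch

criterion = OnlineTripletMiningLoss(margin=0.2, hard=True)
embeddings = torch.randn(32, 128)
labels = torch.randint(0, 8, (32,))
loss = criterion(embeddings, labels)  # 0.0 if no violating triplet is found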
def forward(self, embedded_a, embedded_p, embedded_n):
    return F.triplet_margin_loss(embedded_a, embedded_p, embedded_n,
                                 margin=self.margin, p=2, eps=1e-6, swap=False)
def criterion_tripletmargin(input, positive, negative, margin=0.1, size_average=True):
    loss = F.triplet_margin_loss(input, positive, negative, margin=margin)
    if not size_average:
        # F.triplet_margin_loss averages over the batch by default;
        # multiplying by the batch size recovers the summed loss.
        return loss * input.size(0)
    return loss
def triplet_margin_loss(anchor, positives, negatives, margin=0.25, p=2,
                        eps=1e-6, factor=1, swap=False):
    return factor * func.triplet_margin_loss(
        anchor, positives, negatives, margin=margin, p=p, eps=eps, swap=swap)
def loss_fn(self, prediction, target=None):
    """
    Triplet loss over an (anchor, positive, negative) prediction tuple.

    :param prediction: tuple of (anchor, positive, negative) embeddings
    :param target: unused
    :return: scalar triplet margin loss
    """
    return F.triplet_margin_loss(anchor=prediction[0],
                                 positive=prediction[1],
                                 negative=prediction[2],
                                 margin=self.loss_margin,
                                 p=self.loss_p)
def forward(self, embeddings, target):
    triplets = self.triplet_selector.get_triplets(embeddings, target)
    if embeddings.is_cuda:
        triplets = triplets.cuda()
    anchor = embeddings[triplets[:, 0]]
    positive = embeddings[triplets[:, 1]]
    negative = embeddings[triplets[:, 2]]
    return F.triplet_margin_loss(anchor, positive, negative, self.margin,
                                 self.p, self.eps, self.swap)
class TripletMarginLoss(torch.nn.Module):
    def __init__(self, margin=1.0, p=2, eps=1e-6, swap=True):
        super(TripletMarginLoss, self).__init__()
        self.margin = margin
        self.p = p
        self.eps = eps
        self.swap = swap

    # Tuple-unpacking parameters are Python 2 syntax; take the three
    # embeddings as separate arguments instead.
    def forward(self, anchor, positive, negative):
        return F.triplet_margin_loss(anchor, positive, negative, self.margin,
                                     self.p, self.eps, self.swap)
def train_forward(self, x, epoch):
    img, attr_label, obj_label = x[0], x[1], x[2]
    neg_attrs, neg_objs = x[4], x[5]

    vpos = self.compose(attr_label, obj_label)
    vneg = self.compose(neg_attrs, neg_objs)
    if self.args.img_embed == 1 or self.args.glove_init == 1:
        img = self.image_feat(img)
    loss = F.triplet_margin_loss(img, vpos, vneg, margin=0.5)
    return loss
def TripletLoss(anchor, positive, negatives):
    """
    We found that summing the loss over all negatives yields relatively
    better performance.
    """
    batch_size, neg_num, embed_size = negatives.size()
    # Bring the negatives axis to the front: (neg_num, batch_size, embed_size).
    # A plain .view() would reinterpret memory instead of transposing.
    negatives = negatives.transpose(0, 1)
    losses = 0
    for negative in negatives:
        losses += F.triplet_margin_loss(anchor, positive, negative)
    return losses / neg_num
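# Illustrative call to TripletLoss above with five negatives per anchor;
# the shapes are assumptions for demonstration only.
import torch

anchor = torch.randn(8, 256)
positive = torch.randn(8, 256)
negatives = torch.randn(8, 5, 256)  # (batch_size, neg_num, embed_size)
loss = TripletLoss(anchor, positive, negatives)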
def train_model(num_epochs, optim_name=""): model = DeepRank() if torch.cuda.is_available(): model.cuda() if optim_name == "adam": optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE) elif optim_name == "rms": optimizer = optim.RMSprop(model.parameters(), lr=LEARNING_RATE) else: optimizer = optim.SGD(model.parameters(), lr=LEARNING_RATE, momentum=0.9, nesterov=True) print(f'==> Selected optimizer : {optim_name}') model.train() # set to training mode start_time = time.time() for epoch in range(num_epochs): print(f'Epoch {epoch + 1}/{num_epochs}') print('-' * 10) running_loss = [] for batch_idx, (Q, P, N) in enumerate(train_loader): if torch.cuda.is_available(): Q, P, N = Variable(Q).cuda(), Variable(P).cuda(), Variable(N).cuda() else: Q, P, N = Variable(Q), Variable(P), Variable(N) # set gradient to 0 optimizer.zero_grad() Q_embedding, P_embedding, N_embedding = model(Q), model(P), model(N) # get triplet loss loss = F.triplet_margin_loss(anchor=Q_embedding, positive=P_embedding, negative=N_embedding) # back-propagate & optimize loss.backward() optimizer.step() # calculate loss running_loss.append(loss.item()) print(f'\t--> epoch{epoch+1} {100 * batch_idx / len(train_loader):.2f}% done... loss : {loss:.4f}') epoch_loss = np.mean(running_loss) print(f'epoch{epoch+1} average loss: {epoch_loss:.2f}') finish_time = time.time() print(f'elapsed time : {time.strftime("%H:%M:%S", time.gmtime(finish_time - start_time))}') torch.save(model.state_dict(), MODEL_PATH) # save model parameters
def test(model, test_loader, epoch, margin, threshold, is_cuda=True, log_interval=1000):
    model.eval()
    test_loss = AverageMeter()
    accuracy = 0
    num_p = 0
    total_num = 0
    batch_num = len(test_loader)
    with torch.no_grad():
        for batch_idx, (data_a, data_p, data_n, target) in enumerate(test_loader):
            if is_cuda:
                data_a = data_a.cuda()
                data_p = data_p.cuda()
                data_n = data_n.cuda()
                target = target.cuda()

            out_a = model(data_a)
            out_p = model(data_p)
            out_n = model(data_n)

            loss = F.triplet_margin_loss(out_a, out_p, out_n, margin)
            dist1 = F.pairwise_distance(out_a, out_p)
            dist2 = F.pairwise_distance(out_a, out_n)
            # Count pairs on the correct side of the distance threshold.
            num = ((dist1 < threshold).sum() + (dist2 > threshold).sum()).item()
            num_p += num
            num_p = 1.0 * num_p
            total_num += data_a.size()[0] * 2
            test_loss.update(loss.item())
            if (batch_idx + 1) % log_interval == 0:
                accuracy_tmp = num_p / total_num
                print('Test- Epoch {:04d}\tbatch:{:06d}/{:06d}\tAccuracy:{:.04f}\tloss:{:06f}'
                      .format(epoch, batch_idx + 1, batch_num, accuracy_tmp, test_loss.avg))
                test_loss.reset()

    accuracy = num_p / total_num
    return accuracy
def epoch(self):
    self.model.train()
    for (anchor, positive, negative) in self.dataloader:
        if anchor.size()[0] != self.batch_size:
            continue
        if self.cuda:
            anchor = anchor.cuda()
            positive = positive.cuda()
            negative = negative.cuda()

        z_anchor = self.model(anchor)
        z_positive = self.model(positive)

        select_triplets = self.use_hard_triplets
        if self.use_hard_triplets and self.current_epoch < 4:
            # Warm up on the first candidate negative before hard mining.
            negative = negative[:, 0]
            select_triplets = False

        if select_triplets:
            # Embed every candidate negative and keep the one closest to the anchor.
            ext_anchor = torch.stack([z_anchor], dim=1)
            n_neg_samples = negative.size(1)
            all_negs = negative.view(-1, negative.size(2), negative.size(3),
                                     negative.size(4))
            z_all_negs = self.model(all_negs).view(self.batch_size,
                                                   n_neg_samples,
                                                   z_anchor.size(1))
            distances = torch.sum(torch.pow(ext_anchor - z_all_negs, 2), dim=2)
            hard_sample_positions = torch.argmin(distances, dim=1).type(torch.int64)
            batch_indices = torch.arange(0, self.batch_size).type(torch.int64)
            if self.cuda:
                batch_indices = batch_indices.cuda()
            z_negative = z_all_negs[batch_indices, hard_sample_positions]
        else:
            z_negative = self.model(negative)

        loss = F.triplet_margin_loss(z_anchor, z_positive, z_negative,
                                     margin=self.margin)

        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()

    return {
        "epoch": self.current_epoch,
        "losses": {
            "loss": loss.detach().item(),
        },
        "networks": {
            "model": self.model,
        },
        "optimizers": {
            "optimizer": self.optimizer,
        },
    }
def get_loss1_loss2_food(model, img_loader, device, data_loader_wo_img):
    model.eval()
    loss1 = []
    loss2 = []
    score_dict = get_all_rank_scores_lfw(model, img_loader, device)
    with torch.no_grad():
        for anc, pos, neg, _ in data_loader_wo_img:
            anc, pos, neg = anc.to(device), pos.to(device), neg.to(device)
            for a, p, n in zip(anc, pos, neg):
                output1 = score_dict[int(a)].view(1, -1)
                output2 = score_dict[int(p)].view(1, -1)
                output3 = score_dict[int(n)].view(1, -1)
                l1 = triplet_margin_loss(output1, output2, output3, margin=1.0, p=2)
                l2 = triplet_margin_loss(output1, output3, output2, margin=1.0, p=2)
                loss1.append(l1.item())
                loss2.append(l2.item())
    return np.array(loss1, dtype=float), np.array(loss2, dtype=float)
def _get_loss(self, output_anchors, output_positives, output_negatives, B, loss_type):
    L = output_anchors.size(-1)
    if loss_type == 'triplet':
        # Repeat each anchor/positive once per negative, then flatten.
        output_anchors = output_anchors.unsqueeze(1).expand_as(
            output_negatives).contiguous().view(-1, L)
        output_positives = output_positives.unsqueeze(1).expand_as(
            output_negatives).contiguous().view(-1, L)
        output_negatives = output_negatives.contiguous().view(-1, L)
        loss = F.triplet_margin_loss(output_anchors, output_positives,
                                     output_negatives, margin=self.margin,
                                     p=2, reduction='mean')
    elif loss_type == 'sare_joint':
        dist_pos = torch.mm(output_anchors, output_positives.transpose(0, 1))  # B*B
        dist_pos = dist_pos.diagonal(0)
        dist_pos = dist_pos.view(B, 1)

        output_anchors = output_anchors.unsqueeze(1).expand_as(
            output_negatives).contiguous().view(-1, L)
        output_negatives = output_negatives.contiguous().view(-1, L)
        dist_neg = torch.mm(output_anchors, output_negatives.transpose(0, 1))  # B*B
        dist_neg = dist_neg.diagonal(0)
        dist_neg = dist_neg.view(B, -1)

        # joint optimize: one softmax over the positive and all negatives
        dist = torch.cat((dist_pos, dist_neg), 1) / self.temp[0]
        dist = F.log_softmax(dist, 1)
        loss = (-dist[:, 0]).mean()
    elif loss_type == 'sare_ind':
        dist_pos = torch.mm(output_anchors, output_positives.transpose(0, 1))  # B*B
        dist_pos = dist_pos.diagonal(0)
        dist_pos = dist_pos.view(B, 1)

        output_anchors = output_anchors.unsqueeze(1).expand_as(
            output_negatives).contiguous().view(-1, L)
        output_negatives = output_negatives.contiguous().view(-1, L)
        dist_neg = torch.mm(output_anchors, output_negatives.transpose(0, 1))  # B*B
        dist_neg = dist_neg.diagonal(0)
        dist_neg = dist_neg.view(B, -1)

        # individual optimize: one softmax per (positive, negative) pair
        dist_neg = dist_neg.unsqueeze(2)
        dist_pos = dist_pos.view(B, 1, 1).expand_as(dist_neg)
        dist = torch.cat((dist_pos, dist_neg), 2).view(-1, 2) / self.temp[0]
        dist = F.log_softmax(dist, 1)
        loss = (-dist[:, 0]).mean()
    else:
        # `assert ("...")` on a non-empty string is always true; raise instead.
        raise ValueError('Unknown loss type: %s' % loss_type)
    return loss
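# A standalone sketch of the 'sare_ind' branch above: each (positive, negative)
# similarity pair is softmax-normalised individually, and the positive's
# log-probability is maximised. The shapes and temperature are assumptions.
import torch
import torch.nn.functional as F

B, neg_num, temp = 4, 10, 0.07
dist_pos = torch.randn(B, 1)        # anchor-positive similarities
dist_neg = torch.randn(B, neg_num)  # anchor-negative similarities

pairs = torch.stack([dist_pos.expand(B, neg_num), dist_neg], dim=2)  # (B, neg_num, 2)
log_p = F.log_softmax(pairs.view(-1, 2) / temp, dim=1)
loss = (-log_p[:, 0]).mean()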
def anticipation_loss(self, frame_feats, lstm_feats, batch):
    B = frame_feats.shape[0]
    T = lstm_feats.shape[1]
    positive, negative = batch['positive'], batch['negative']
    target, length = batch['verb'], batch['length']

    # select the active frame from the clip
    lstm_preds = self.fc(lstm_feats)  # (B, max_L, #classes)
    lstm_preds = lstm_preds.view(B * T, -1)
    target_flat = target.unsqueeze(1).expand(target.shape[0], T).contiguous().view(-1)
    pred_scores = -F.cross_entropy(lstm_preds, target_flat,
                                   reduction='none').view(B, T)
    _, frame_idx = pred_scores.max(1)
    frame_idx = torch.min(frame_idx, length - 1)  # don't select a padding frame!
    active_feats = frame_feats[torch.arange(B), frame_idx]  # (B, 256, 28, 28)
    active_pooled = self.pool(active_feats).view(B, -1)

    def embed(x):
        pred_frame = self.project(self.backbone(x))
        pooled = self.pool(pred_frame).view(B, -1)
        return pooled

    positive_pooled = embed(positive)
    _, (hn, cn) = self.rnn(positive_pooled.unsqueeze(1),
                           self.get_hidden_state(B, positive_pooled.device))
    preds = self.fc(hn[-1])
    aux_loss = F.cross_entropy(preds, target, reduction='none')

    if self.ant_loss == 'mse':
        ant_loss = 0.1 * ((positive_pooled - active_pooled) ** 2).mean(1)
    elif self.ant_loss == 'triplet':
        negative_pooled = self.backbone(negative)
        negative_pooled = self.pool(negative_pooled).view(B, -1)
        anc, pos, neg = (F.normalize(positive_pooled, 2),
                         F.normalize(active_pooled, 2),
                         F.normalize(negative_pooled, 2))
        ant_loss = F.triplet_margin_loss(anc, pos, neg, margin=0.5,
                                         reduction='none')

    return {'ant_loss': ant_loss, 'aux_loss': aux_loss}
def train_epoch(net, dset, optimizer, batch_size, epoch_index, out_dir, log=None):
    gen_triplets(dset)
    net.train()
    batch_num = dset.num_triplets // batch_size
    pbar = tqdm(range(batch_num))
    for batch_idx in pbar:
        batch_start = batch_size * batch_idx
        patches = get_batch(dset, batch_start, batch_size)
        pos_a, pos_b, neg = tuple(wrap_torch(p) for p in patches)

        # output loss on descriptors
        loss = F.triplet_margin_loss(
            net(pos_a), net(pos_b), net(neg),
            margin=LOSS_MARGIN,
            swap=LOSS_ANCHORSWAP,
        )

        # compute gradient and update weights
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        learn_rate_decay(optimizer)

        log.write('loss', loss.item())
        log.advance_time()
        if batch_idx % 100 == 0:
            pbar.set_description(
                'Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}'.format(
                    epoch_index, batch_start, dset.num_triplets,
                    100. * batch_idx / batch_num, loss.item()))

    chkpt_file = os.path.join(out_dir, CHECKPOINT_FORMAT.format(n=epoch_index))
    save_weights(net, chkpt_file, extra_attrs=dict(epoch=epoch_index))