def meta_train(self, examplar, search, gt_cls, gt_loc, gt_loc_weight):
    examplar = self.backbone(examplar)
    search = self.backbone(search)
    if cfg.ADJUST.USE:
        examplar = self.neck(examplar)
        search = self.neck(search)

    # first inner-loop iteration: adapt init_weight by one gradient step
    pred_cls, pred_loc = self.rpn(examplar, search, self.init_weight, self.bn_weight)
    pred_cls = self.log_softmax(pred_cls)
    cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
    loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
    total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss
    grads = torch.autograd.grad(total_loss, self.init_weight.values(),
                                retain_graph=True, create_graph=True)
    new_init_weight = OrderedDict(
        (k, iw - a * g) for (k, iw), a, g in zip(
            self.init_weight.items(), self.alpha.values(), grads))

    # second inner-loop iteration: take one more step from the updated weights,
    # keeping the graph so meta-gradients can flow back in meta_eval
    pred_cls, pred_loc = self.rpn(examplar, search, new_init_weight, self.bn_weight)
    pred_cls = self.log_softmax(pred_cls)
    cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
    loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
    total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss
    grads = torch.autograd.grad(total_loss, new_init_weight.values(), create_graph=True)
    new_init_weight = OrderedDict(
        (k, iw - a * g) for (k, iw), a, g in zip(
            new_init_weight.items(), self.alpha.values(), grads))
    return new_init_weight
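# A minimal sketch of how the meta-learned state used above could be set up.
# The excerpt does not show where self.init_weight and self.alpha come from;
# this assumes the RPN head exposes its fast weights as a name->tensor dict
# (the helper name `rpn_param_dict` and the 1e-2 init value are assumptions):
# `init_weight` holds the weights adapted in meta_train, and `alpha` holds one
# learnable per-parameter learning rate of the same shape as each weight.
def build_meta_state(rpn_param_dict):
    init_weight = OrderedDict(
        (name, nn.Parameter(p.clone().detach())) for name, p in rpn_param_dict.items())
    alpha = OrderedDict(
        (name, nn.Parameter(torch.full_like(p, 1e-2))) for name, p in rpn_param_dict.items())
    return init_weight, alpha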
def meta_eval(self, new_init_weight, examplar, search, gt_cls, gt_loc, gt_loc_weight):
    examplar = self.backbone(examplar)
    search = self.backbone(search)
    if cfg.ADJUST.USE:
        examplar = self.neck(examplar)
        search = self.neck(search)

    pred_cls, pred_loc = self.rpn(examplar, search, new_init_weight, self.bn_weight)
    pred_cls = self.log_softmax(pred_cls)
    cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
    loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
    total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss

    # meta-gradients of the query loss w.r.t. the initial weights and the
    # per-parameter learning rates; the loss reaches them through the
    # inner-loop updates kept in the graph by create_graph=True in meta_train
    init_grad_vals = torch.autograd.grad(total_loss, self.init_weight.values(),
                                         retain_graph=True)
    alpha_grad_vals = torch.autograd.grad(total_loss, self.alpha.values(),
                                          retain_graph=True)

    # repack the gradient tuples as ordered dicts keyed like the parameters
    init_grads = OrderedDict(
        (k, g) for k, g in zip(self.init_weight.keys(), init_grad_vals))
    alpha_grads = OrderedDict(
        (k, g) for k, g in zip(self.alpha.keys(), alpha_grad_vals))
    return init_grads, alpha_grads, total_loss
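# A minimal sketch of an outer loop that consumes meta_train / meta_eval,
# assuming `model` is this module and `tasks` yields (support, query) dicts
# whose keys mirror the argument names above; `meta_lr` and the plain SGD
# update on init_weight / alpha are assumptions, not part of this excerpt.
def meta_update(model, tasks, meta_lr=1e-3):
    for support, query in tasks:
        new_init_weight = model.meta_train(
            support['examplar'], support['search'],
            support['gt_cls'], support['gt_loc'], support['gt_loc_weight'])
        init_grads, alpha_grads, loss = model.meta_eval(
            new_init_weight,
            query['examplar'], query['search'],
            query['gt_cls'], query['gt_loc'], query['gt_loc_weight'])
        # apply the returned meta-gradients with a plain SGD step
        with torch.no_grad():
            for k in model.init_weight:
                model.init_weight[k] -= meta_lr * init_grads[k]
                model.alpha[k] -= meta_lr * alpha_grads[k]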
def forward(self, examplar, train_search, train_gt_cls, train_gt_loc, train_gt_loc_weight,
            test_search, test_gt_cls, test_gt_loc, test_gt_loc_weight):
    examplar = self.backbone(examplar)
    search = self.backbone(train_search)
    test_search = self.backbone(test_search)
    if cfg.ADJUST.USE:
        examplar = self.neck(examplar)
        search = self.neck(search)
        test_search = self.neck(test_search)

    # the localization branch keeps the original (detached) template, while
    # the classification branch uses a gradient-updated template
    loc_examplar = examplar.detach()
    new_examplar = examplar
    new_examplar.requires_grad_(True)

    # loss on the training pair, used to obtain a gradient w.r.t. the template
    pred_cls, pred_loc = self.rpn(new_examplar, search)
    pred_cls = self.log_softmax(pred_cls)
    init_cls_loss = select_cross_entropy_loss(pred_cls, train_gt_cls)
    init_loc_loss = weight_l1_loss(pred_loc, train_gt_loc, train_gt_loc_weight)
    init_total_loss = cfg.TRAIN.CLS_WEIGHT * init_cls_loss + cfg.TRAIN.LOC_WEIGHT * init_loc_loss

    # scale the template gradient and map it to an additive template update
    examplar_grad = torch.autograd.grad(init_cls_loss, new_examplar)[0] * 1000
    new_examplar = new_examplar + self.grad_layer(examplar_grad)

    # classification on the test pair with the updated template
    pred_cls, _ = self.rpn(new_examplar, test_search)
    pred_cls = self.log_softmax(pred_cls)
    cls_loss = select_cross_entropy_loss(pred_cls, test_gt_cls)

    # localization on the test pair with the unmodified template
    _, pred_loc = self.rpn(loc_examplar, test_search)
    loc_loss = weight_l1_loss(pred_loc, test_gt_loc, test_gt_loc_weight)
    total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss

    return {
        'cls_loss': cls_loss,
        'loc_loss': loc_loss,
        'total_loss': total_loss,
        'init_cls_loss': init_cls_loss,
        'init_loc_loss': init_loc_loss,
        'init_total_loss': init_total_loss,
        'examplar_grad': examplar_grad
    }
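# The definition of `grad_layer` is not shown in this excerpt. A minimal sketch
# of one plausible choice is a small conv block that maps the scaled template
# gradient to an additive template update with the channel count unchanged;
# the class name, two-layer structure, and channel width 256 are assumptions.
class GradLayer(nn.Module):
    def __init__(self, channels=256):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1))

    def forward(self, grad):
        # produce a residual that is added to the template features
        return self.conv(grad)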
def set_examplar(self, examplar, search, gt_cls, gt_loc, gt_loc_weight):
    examplar = self.backbone(examplar)
    search = self.backbone(search)
    if cfg.ADJUST.USE:
        examplar = self.neck(examplar)
        search = self.neck(search)

    self.examplar = examplar
    self.loc_examplar = examplar.detach()
    self.examplar.requires_grad_(True)

    pred_cls, pred_loc = self.rpn(self.examplar, search)
    pred_cls = self.log_softmax(pred_cls)
    cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)

    # backward for the grad
    examplar_grad = torch.autograd.grad(cls_loss, self.examplar)[0] * 1000
    self.examplar = self.examplar + self.grad_layer(examplar_grad)
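# A minimal usage sketch for inference, assuming set_examplar is called once on
# the first frame (with labels derived from the initial bounding box) and that
# later frames reuse the stored templates; `track_frame` and its split between
# self.examplar (classification) and self.loc_examplar (localization) mirror
# the training forward above but are assumptions, not part of this excerpt.
def track_frame(model, search_crop):
    search_feat = model.backbone(search_crop)
    if cfg.ADJUST.USE:
        search_feat = model.neck(search_feat)
    pred_cls, _ = model.rpn(model.examplar, search_feat)
    _, pred_loc = model.rpn(model.loc_examplar, search_feat)
    return pred_cls, pred_loc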
def forward(self, examplar, search, gt_cls, gt_loc, gt_loc_weight):
    # normal forward
    examplar = self.backbone(examplar)
    search = self.backbone(search)
    if cfg.ADJUST.USE:
        examplar = self.neck(examplar)
        search = self.neck(search)

    pred_cls, pred_loc = self.rpn(examplar, search)
    pred_cls = self.log_softmax(pred_cls)
    cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
    loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
    total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss

    return {
        'cls_loss': cls_loss,
        'loc_loss': loc_loss,
        'total_loss': total_loss
    }
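# A minimal sketch of a standard training step driving this forward pass,
# assuming an optimizer already built over the model's trainable parameters
# and a batch dict whose keys mirror the argument names above; both of those
# and the function name `train_step` are assumptions.
def train_step(model, optimizer, batch):
    outputs = model(batch['examplar'], batch['search'],
                    batch['gt_cls'], batch['gt_loc'], batch['gt_loc_weight'])
    optimizer.zero_grad()
    outputs['total_loss'].backward()
    optimizer.step()
    return outputs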