Example 1
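A MAML-style inner loop for the RPN head: two differentiable SGD steps on the initial weights, with per-parameter learning rates held in self.alpha. The snippets on this page appear to come from a PySOT-style Siamese tracker and assume the usual module-level imports (torch, OrderedDict from collections, the project cfg, and the select_cross_entropy_loss / weight_l1_loss helpers).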
    def meta_train(self, examplar, search, gt_cls, gt_loc, gt_loc_weight):
        examplar = self.backbone(examplar)
        search = self.backbone(search)
        if cfg.ADJUST.USE:
            examplar = self.neck(examplar)
            search = self.neck(search)
        # first inner-loop step: forward with the initial weights
        pred_cls, pred_loc = self.rpn(examplar, search, self.init_weight,
                                      self.bn_weight)
        pred_cls = self.log_softmax(pred_cls)
        cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
        loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
        total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss

        # create_graph=True keeps this update differentiable, so the outer
        # (meta) loss can later backpropagate through the inner step
        grads = torch.autograd.grad(total_loss,
                                    self.init_weight.values(),
                                    retain_graph=True,
                                    create_graph=True)
        # per-parameter SGD step: w' = w - alpha * grad
        new_init_weight = OrderedDict((k, iw - a * g) for (k, iw), a, g in zip(
            self.init_weight.items(), self.alpha.values(), grads))
        # second inner-loop step: forward with the adapted weights
        pred_cls, pred_loc = self.rpn(examplar, search, new_init_weight,
                                      self.bn_weight)
        pred_cls = self.log_softmax(pred_cls)
        cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
        loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
        total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss
        # second differentiable update, starting from the adapted weights
        grads = torch.autograd.grad(total_loss,
                                    new_init_weight.values(),
                                    create_graph=True)
        new_init_weight = OrderedDict((k, iw - a * g) for (k, iw), a, g in zip(
            new_init_weight.items(), self.alpha.values(), grads))
        return new_init_weight
Example 2
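The matching meta-evaluation step: it scores the weights adapted by meta_train on a held-out exemplar/search pair and returns the gradients of that loss with respect to the initial weights and the per-parameter learning rates, for the caller to apply as the outer update.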
    def meta_eval(self, new_init_weight, examplar, search, gt_cls, gt_loc,
                  gt_loc_weight):
        examplar = self.backbone(examplar)
        search = self.backbone(search)
        if cfg.ADJUST.USE:
            examplar = self.neck(examplar)
            search = self.neck(search)
        pred_cls, pred_loc = self.rpn(examplar, search, new_init_weight,
                                      self.bn_weight)
        pred_cls = self.log_softmax(pred_cls)
        cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
        loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
        total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss
        # gradients of the meta-eval loss w.r.t. the initial weights and the
        # per-parameter learning rates; retain_graph=True on both calls
        # because total_loss is also returned and its graph may still be
        # needed by the caller
        init_grad_vals = torch.autograd.grad(total_loss,
                                             self.init_weight.values(),
                                             retain_graph=True)
        alpha_grad_vals = torch.autograd.grad(total_loss,
                                              self.alpha.values(),
                                              retain_graph=True)
        # repack the gradients as ordered dicts keyed like the parameters
        init_grads = OrderedDict(zip(self.init_weight.keys(), init_grad_vals))
        alpha_grads = OrderedDict(zip(self.alpha.keys(), alpha_grad_vals))
        return init_grads, alpha_grads, total_loss
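A minimal sketch of how the two methods above might be composed into a MAML-style outer step. The model and meta_optimizer names, the batch layout, and the manual .grad assignment are assumptions, not part of the original code; it presumes init_weight and alpha hold nn.Parameter tensors registered with the optimizer.

    # Hypothetical MAML-style outer step built on meta_train/meta_eval.
    # `model`, `meta_optimizer`, `support`, and `query` are assumed names.
    def outer_step(model, meta_optimizer, support, query):
        # inner loop: two adapted steps on the support pair
        adapted = model.meta_train(support['examplar'], support['search'],
                                   support['gt_cls'], support['gt_loc'],
                                   support['gt_loc_weight'])
        # evaluate the adapted weights on the held-out query pair
        init_grads, alpha_grads, loss = model.meta_eval(
            adapted, query['examplar'], query['search'],
            query['gt_cls'], query['gt_loc'], query['gt_loc_weight'])
        # autograd.grad does not populate .grad, so write the returned
        # gradients in manually and let the meta optimizer take a step
        for k, w in model.init_weight.items():
            w.grad = init_grads[k]
        for k, a in model.alpha.items():
            a.grad = alpha_grads[k]
        meta_optimizer.step()
        meta_optimizer.zero_grad()
        return loss.item()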
Example 3
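A gradient-guided template update at training time: the classification gradient with respect to the template features is scaled, passed through a learned grad_layer, and added to the template before scoring the test pair; localization keeps the detached, unmodified template.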
    def forward(self, examplar, train_search, train_gt_cls, train_gt_loc,
                train_gt_loc_weight, test_search, test_gt_cls, test_gt_loc,
                test_gt_loc_weight):
        examplar = self.backbone(examplar)
        search = self.backbone(train_search)
        test_search = self.backbone(test_search)
        if cfg.ADJUST.USE:
            examplar = self.neck(examplar)
            search = self.neck(search)
            test_search = self.neck(test_search)

        # keep a detached copy of the template for the localization branch
        loc_examplar = examplar.detach()
        new_examplar = examplar
        # treat the template features as a leaf tensor so we can take
        # gradients w.r.t. them (this assumes the backbone output does not
        # already require grad, e.g. a frozen backbone)
        new_examplar.requires_grad_(True)
        pred_cls, pred_loc = self.rpn(new_examplar, search)
        pred_cls = self.log_softmax(pred_cls)
        init_cls_loss = select_cross_entropy_loss(pred_cls, train_gt_cls)
        init_loc_loss = weight_l1_loss(pred_loc, train_gt_loc,
                                       train_gt_loc_weight)
        init_total_loss = cfg.TRAIN.CLS_WEIGHT * init_cls_loss + cfg.TRAIN.LOC_WEIGHT * init_loc_loss
        # gradient of the cls loss w.r.t. the template features, scaled up,
        # then mapped through a learned layer into an additive template update
        examplar_grad = torch.autograd.grad(init_cls_loss,
                                            new_examplar)[0] * 1000
        new_examplar = new_examplar + self.grad_layer(examplar_grad)

        # classification uses the gradient-refined template
        pred_cls, _ = self.rpn(new_examplar, test_search)
        pred_cls = self.log_softmax(pred_cls)
        cls_loss = select_cross_entropy_loss(pred_cls, test_gt_cls)

        # localization uses the detached, unmodified template
        _, pred_loc = self.rpn(loc_examplar, test_search)
        loc_loss = weight_l1_loss(pred_loc, test_gt_loc, test_gt_loc_weight)
        total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss
        return {
            'cls_loss': cls_loss,
            'loc_loss': loc_loss,
            'total_loss': total_loss,
            'init_cls_loss': init_cls_loss,
            'init_loc_loss': init_loc_loss,
            'init_total_loss': init_total_loss,
            'examplar_grad': examplar_grad
        }
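grad_layer itself is not shown in these snippets. One plausible shape for it, purely as an assumption, is a small conv block that maps the template-feature gradient to an additive update (the channel count of 256 is also assumed):

    import torch.nn as nn

    class GradLayer(nn.Module):
        # hypothetical module: maps the template-feature gradient to an
        # additive template update; architecture and channel count assumed
        def __init__(self, channels=256):
            super().__init__()
            self.transform = nn.Sequential(
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            )

        def forward(self, grad):
            return self.transform(grad)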
Example 4
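First-frame initialization for tracking: the same gradient-guided refinement as above, applied once and cached on the module as self.examplar and self.loc_examplar. Note that gt_loc, gt_loc_weight, and pred_loc are accepted but unused here.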
    def set_examplar(self, examplar, search, gt_cls, gt_loc, gt_loc_weight):
        examplar = self.backbone(examplar)
        search = self.backbone(search)
        if cfg.ADJUST.USE:
            examplar = self.neck(examplar)
            search = self.neck(search)
        # cache the template; the localization branch keeps a detached copy
        self.examplar = examplar
        self.loc_examplar = examplar.detach()
        self.examplar.requires_grad_(True)

        pred_cls, pred_loc = self.rpn(self.examplar, search)
        pred_cls = self.log_softmax(pred_cls)
        cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
        # gradient of the cls loss w.r.t. the template, scaled and mapped to
        # a one-shot template refinement (same scheme as in forward above)
        examplar_grad = torch.autograd.grad(cls_loss, self.examplar)[0] * 1000
        self.examplar = self.examplar + self.grad_layer(examplar_grad)
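A hypothetical first-frame/per-frame usage; no track() method appears in these snippets, so the per-frame code below is an assumption built from the class internals:

    # hypothetical tracking-time usage; names are illustrative
    model.set_examplar(init_examplar, init_search, gt_cls, gt_loc,
                       gt_loc_weight)
    # later frames reuse the refined template cached on the module
    search_feat = model.backbone(frame_crop)
    if cfg.ADJUST.USE:
        search_feat = model.neck(search_feat)
    pred_cls, pred_loc = model.rpn(model.examplar, search_feat)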
Example 5
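A plain supervised forward pass with no meta-learning or template refinement, returning the weighted classification and localization losses.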
    def forward(self, examplar, search, gt_cls, gt_loc, gt_loc_weight):
        # normal forward
        examplar = self.backbone(examplar)
        search = self.backbone(search)
        if cfg.ADJUST.USE:
            examplar = self.neck(examplar)
            search = self.neck(search)
        pred_cls, pred_loc = self.rpn(examplar, search)
        pred_cls = self.log_softmax(pred_cls)
        cls_loss = select_cross_entropy_loss(pred_cls, gt_cls)
        loc_loss = weight_l1_loss(pred_loc, gt_loc, gt_loc_weight)
        total_loss = cfg.TRAIN.CLS_WEIGHT * cls_loss + cfg.TRAIN.LOC_WEIGHT * loc_loss
        return {
            'cls_loss': cls_loss,
            'loc_loss': loc_loss,
            'total_loss': total_loss
        }
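A minimal sketch of a conventional training step driving this forward(); model, optimizer, and the batch dictionary are assumptions:

    # hypothetical training step; `model`, `optimizer`, and `batch` are assumed
    outputs = model(batch['examplar'], batch['search'], batch['gt_cls'],
                    batch['gt_loc'], batch['gt_loc_weight'])
    loss = outputs['total_loss']
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()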