Example #1
0
File: loss.py Project: whjzsy/TACT
 def get_rcnn_loss(self, cf, op, bb, br, gb):
     """Return ``(loss, total_pos)`` for the R-CNN head, averaged over the batch.

     Shapes (per the original shape notes):
         cf: [numb, numbb, 1+nnum, 2] binary class logits. Slot 0 scores the
             positive input sample; slots 1.. score the negative input samples.
         op: [numb, numbb, 1] predicted IoU score (read only when
             ``self.head_oproi`` is set).
         bb: [numb, numbb, 4] refined boxes (xyxy).
         br: [numb, numbb, 4] unrefined proposals (xyxy).
         gb: [numb, 4] ground-truth box (xyxy).

     Proposals with IoU(gt, proposal) >= ``self.bbox_thres[0]`` are positives.
     Proposals with IoU < ``self.bbox_thres[1]`` are negatives. Positives are
     trained toward class 0 and negatives toward class 1. A batch item without
     positives contributes 0 to the loss. ``total_pos`` is the number of
     positives summed over the batch.
     """
     batch_size = cf.shape[0]
     extra_neg = cf.shape[2] - 1

     total, pos_count = 0, 0
     for b in range(batch_size):
         scores = cf[b]                                  # [numbox, 1+nnum, 2]
         iou_pred = op[b] if self.head_oproi else None   # [numbox, 1]
         refined = bb[b]                                 # [numbox, 4]
         proposals = br[b]                               # [numbox, 4]
         target = gb[b].unsqueeze(0)                     # [1, 4]
         dev = scores.device

         # Split proposals into positives / negatives by IoU with the target.
         iou_prop = jaccard(target, proposals)[0]
         pos = (iou_prop >= self.bbox_thres[0]).nonzero()[:, 0]
         neg = (iou_prop < self.bbox_thres[1]).nonzero()[:, 0]
         n_pos, n_neg = pos.shape[0], neg.shape[0]
         pos_count += n_pos
         has_pos = n_pos > 0

         # IoU-score regression, positives only.
         op_term = 0.
         if has_pos and self.head_oproi:
             op_term = self.op_loss(iou_pred[..., 0][pos], iou_prop[pos])

         # Positive input sample (slot 0): positives -> class 0, negatives -> class 1.
         fg_target = torch.zeros(n_pos, device=dev).long()
         bg_target = torch.ones(n_neg, device=dev).long()
         fg_term = 0.
         if has_pos:
             fg_term = self.cf_loss(scores[pos, 0, :], fg_target)
         bg_term = 0.
         if n_neg > 0:
             bg_term = self.cf_loss(scores[neg, 0, :], bg_target)

         # Negative input samples (slots 1..) scored on positive boxes -> class 1.
         distractor_term = 0.
         if extra_neg > 0 and has_pos:
             distractor_logits = scores[pos, 1:, :].flatten(0, 1)
             distractor_target = torch.ones(n_pos * extra_neg, device=dev).long()
             distractor_term = self.cf_loss(distractor_logits, distractor_target)
         cls_term = fg_term + bg_term + distractor_term

         # Linear IoU loss on the refined boxes of the positives.
         iou_ref = jaccard(target, refined, eps=1.0)[0]
         box_term = torch.mean(1. - iou_ref[pos]) if has_pos else 0

         total += (cls_term + box_term + op_term) if has_pos else 0.

     total /= batch_size
     return total, pos_count
Example #2
0
    def cc_fast_nms(
        self,
        boxes,
        masks,
        scores,
        iou_threshold: float = 0.5,
        top_k: int = 200,
    ):
        """Class-agnostic Fast NMS.

        Each detection keeps only its best class score. The ``top_k``
        highest-scoring detections are compared pairwise by IoU. A detection is
        dropped when any higher-scoring detection overlaps it by more than
        ``iou_threshold``.

        Args:
            boxes: per-detection boxes, indexed along dim 0.
            masks: per-detection mask data, indexed along dim 0.
            scores: class scores ``[num_classes, num_dets]`` (reduced over
                dim 0 to pick the best class per detection).
            iou_threshold: maximum IoU allowed with a higher-scoring detection.
            top_k: number of detections considered, by descending score.

        Returns:
            ``(boxes, masks, classes, scores)`` of the kept detections. These
            are empty when there are no candidates.
        """
        # Collapse all the classes into 1: best class (and its score) per detection.
        scores, classes = scores.max(dim=0)

        _, idx = scores.sort(0, descending=True)
        idx = idx[:top_k]

        # No candidates (or top_k == 0). The max over an empty IoU matrix below
        # would raise, so return empty selections instead.
        if idx.numel() == 0:
            return boxes[idx], masks[idx], classes[idx], scores[idx]

        boxes_idx = boxes[idx]

        iou = jaccard(boxes_idx, boxes_idx)

        # Zero the diagonal and lower triangle. Column j then only holds IoUs
        # with detections that score higher than detection j.
        iou.triu_(diagonal=1)

        iou_max, _ = torch.max(iou, dim=0)

        # Keep detections that no higher-scoring detection would suppress.
        idx_out = idx[iou_max <= iou_threshold]

        return (
            boxes[idx_out],
            masks[idx_out],
            classes[idx_out],
            scores[idx_out],
        )
Example #3
0
    def default_gt_match(self, gt_box: [torch.Tensor], default_box: torch.Tensor):
        """Label every default box against each sample's ground truth and build offsets.

        Args:
            gt_box: per-sample ground-truth tensors. Rows whose last column is
                negative are dropped. Columns 0-1 are used as the box.
            default_box: default boxes shared by all samples. They are clamped
                to ``[0, cfg.right_border]`` first.

        Returns:
            ``(batch_label, batch_offset)``, one entry per sample:
            * label: -1 by default, 0 where the best overlap is
              ``< cfg.rpn_neg_thresh``, 1 where it is ``>= cfg.rpn_pos_thresh``
              (-1 presumably means "ignored" by the loss; TODO confirm). Moved
              to CUDA.
            * offset: ``box_to_offset(default_box, matched_gt_box)``.
        """

        default_box = torch.clamp(default_box, min=0,
                                  max=cfg.right_border)

        batch_offset = []
        batch_label = []

        for index in range(len(gt_box)):

            gt_box_x = gt_box[index]
            # tmp = gt_box_x[:, -1]
            # Drop padding rows (negative value in the last column).
            keep = gt_box_x[:, -1] >= 0
            gt_box_x = gt_box_x[keep]
            # gt_label = gt_box_x[:, -1]
            # One label per default box, initialised to -1 (neither negative nor positive).
            # NOTE(review): created on CPU but later indexed with masks derived
            # from CUDA tensors; confirm this works on the target torch version.
            default_label = torch.Tensor(default_box.size()[0]).fill_(-1)
            # default_label = torch.Tensor(default_box.size()[0]).fill_(0)

            gt_box_x = gt_box_x[:, :2].cuda()


            # jaccard returns four values here; only overlap (and the unused
            # non_overlap maxima) are consumed below.
            overlap, union, non_overlap,tt = jaccard(gt_box_x, default_box)

            # Best default box for each ground-truth box.
            maxlap_of_ground, maxidx_of_ground = overlap.max(1)

            # Best ground-truth box for each default box.
            maxlap_of_default, maxidx_of_default = overlap.max(0)

            # Unused below except in the commented-out alternative matching.
            nonlap_of_default, nonidx_of_default = non_overlap.max(0)

            # Force each ground-truth box's best default box to an overlap of 2, so
            # it passes the positive threshold (assumes overlap is an IoU <= 1 and
            # cfg.rpn_pos_thresh <= 2 — TODO confirm).
            maxlap_of_default.index_fill_(0, maxidx_of_ground, 2)

            # Make those forced default boxes match the ground-truth box that chose
            # them. The loop runs in order, so a later ground-truth index wins on
            # collisions.
            if len(maxidx_of_ground.size()) >= 1:
                for j in range(maxidx_of_ground.size()[0]):
                    # maxlap_of_default[maxidx_of_ground[j]] = maxlap_of_ground[j]
                    maxidx_of_default[maxidx_of_ground[j]] = j
            else:
                # Only reached if maxidx_of_ground is 0-d.
                maxidx_of_default[maxidx_of_ground] = 0

            # tmp = nonlap_of_default < 39
            # tmp2 = nonlap_of_default >= 39
            # matches = gt_box_x[nonidx_of_default].squeeze()
            ################################
            # tmp: negatives, tmp2: positives, matches: matched GT box per default box.
            tmp = maxlap_of_default < cfg.rpn_neg_thresh
            tmp2 = maxlap_of_default >= cfg.rpn_pos_thresh
            matches = gt_box_x[maxidx_of_default].squeeze()
            ################################
            default_label[tmp] = 0
            default_label[tmp2] = 1

            # Regression targets for every default box (also computed for non-positives).
            default_gt_offset = box_to_offset(default_box, matches)
            # t=default_gt_offset[default_label>0]
            batch_offset.append(default_gt_offset)
            batch_label.append(default_label.cuda())

        return batch_label, batch_offset
Example #4
0
def fast_nms(box_thre, coef_thre, class_thre, cfg, second_threshold=False):
    """Per-class Fast NMS using a single pairwise-IoU matrix per class.

    Args:
        box_thre: candidate boxes, indexed along dim 0 (gathered into
            ``[num_classes, num_dets, 4]`` below).
        coef_thre: per-candidate coefficients, indexed like ``box_thre``.
        class_thre: per-class scores ``[num_classes, num_candidates]``.
        cfg: config providing ``top_k``, ``nms_iou_thre``, ``max_detections``
            and, when ``second_threshold`` is set, ``nms_score_thresh``.
        second_threshold: also drop detections scoring ``<= cfg.nms_score_thresh``.

    Returns:
        ``(box_nms, coef_nms, class_ids, class_nms)`` for at most
        ``cfg.max_detections`` detections, sorted by descending score. These are
        empty when there are no candidates.
    """
    class_thre, idx = class_thre.sort(1, descending=True)

    idx = idx[:, :cfg.top_k]
    class_thre = class_thre[:, :cfg.top_k]

    num_classes, num_dets = idx.size()
    if num_dets == 0:
        # No candidates. reshape(..., -1) and the IoU max below cannot handle
        # zero-length inputs, so return empty results directly.
        empty = idx.new_zeros(0)
        return box_thre[empty], coef_thre[empty], empty, class_thre.new_zeros(0)

    box_thre = box_thre[idx.reshape(-1), :].reshape(num_classes, num_dets, 4)
    coef_thre = coef_thre[idx.reshape(-1), :].reshape(num_classes, num_dets, -1)

    # Pairwise IoU within each class (assumes jaccard batches over dim 0).
    iou = jaccard(box_thre, box_thre)
    # Keep only IoUs with higher-scoring detections (strict upper triangle).
    iou.triu_(diagonal=1)
    iou_max, _ = iou.max(dim=1)

    # Now just filter out the ones higher than the threshold.
    keep = (iou_max <= cfg.nms_iou_thre)

    # We should also only keep detections over the confidence threshold. Skipping
    # that maxes out the detection count for every image, but we have so little
    # computation per detection (matrix multiplication only) that the increase
    # barely matters (+0.2 mAP for 34 -> 33 fps), so it is optional here.
    if second_threshold:
        keep &= (class_thre > cfg.nms_score_thresh)

    # Assign each kept detection to its corresponding class.
    class_ids = torch.arange(num_classes,
                             device=box_thre.device)[:, None].expand_as(keep)

    class_ids = class_ids[keep]

    box_nms = box_thre[keep]
    coef_nms = coef_thre[keep]
    class_nms = class_thre[keep]

    # Only keep the cfg.max_detections highest scores across all classes.
    class_nms, idx = class_nms.sort(0, descending=True)

    idx = idx[:cfg.max_detections]
    class_nms = class_nms[:cfg.max_detections]

    class_ids = class_ids[idx]
    box_nms = box_nms[idx]
    coef_nms = coef_nms[idx]

    return box_nms, coef_nms, class_ids, class_nms
Example #5
0
    def _calccuracy(self):
        """Count TP/FP/FN/TN over the predicted teeth and return the accuracy.

        For each prediction ``i`` with label ``p_label`` (1-based; target labels
        are 0-based) and confidence ``self.scores[i][0]``:

        * label present in ``self.targets_label``: conf < 0.5 -> FN. Otherwise
          TP if the IoU with the first target box carrying that label is > 0.5,
          else FP.
        * label absent: conf < 0.5 -> TN, otherwise FP.

        Misses are recorded in ``self.misslist`` under ``labels[p_label - 1]``
        as 'FN', 'FPE' (tooth exists but poor IoU) or 'FPN' (tooth does not
        exist).

        Returns:
            ``(tp + tn) / (tp + tn + fp + fn)``, or 0.0 when there is nothing
            to score.
        """
        tp = fp = fn = tn = 0
        self.misslist = {}

        # overlaps[t, p]: IoU between target box t and predicted box p.
        overlaps = jaccard(self.targets_box, self.prediction_box)
        for i, p_label in enumerate(self.prediction_label):
            # Predicted labels range 1~32, target labels range 0~31.
            t_label = p_label - 1
            # Read the confidence the same way in both branches (the absent-tooth
            # branch used to compare the whole row, which fails for rows with
            # more than one element).
            score = self.scores[i][0].item()
            if t_label in self.targets_label:
                # Tooth exists. Conf < 0.5 means no solid detection.
                if score < 0.5:
                    fn += 1
                    self.misslist[labels[t_label]] = 'FN'
                elif score >= 0.5:
                    # [0]: with dirty data the label may appear more than once
                    # among the targets; only the first occurrence is used.
                    same_label_iou = overlaps[self.targets_label == t_label, i][0]
                    if same_label_iou.item() > 0.5:
                        tp += 1
                    else:
                        fp += 1
                        self.misslist[labels[t_label]] = 'FPE'
            else:
                # Tooth does not exist: not detecting it is correct.
                if score < 0.5:
                    tn += 1
                elif score >= 0.5:
                    fp += 1
                    self.misslist[labels[t_label]] = 'FPN'

        total = tp + tn + fp + fn
        print('accuracy = ({}+{})/({}+{}+{}+{})'.format(tp, tn, tp, tn, fp, fn))
        if total == 0:
            # Nothing was scored; avoid ZeroDivisionError.
            return 0.0
        # accuracy = (TP+TN)/ALL
        return (tp + tn) / total
Example #6
0
    def fast_nms(
        self,
        boxes,
        masks,
        scores,
        iou_threshold: float = 0.5,
        top_k: int = 200,
        second_threshold: bool = False,
    ):
        """Per-class Fast NMS.

        Args:
            boxes: candidate boxes, indexed along dim 0 (gathered into
                ``[num_classes, num_dets, 4]`` below).
            masks: per-candidate mask data, indexed like ``boxes``.
            scores: per-class scores ``[num_classes, num_candidates]``.
            iou_threshold: a detection is dropped if a higher-scoring detection
                of the same class overlaps it by more than this IoU.
            top_k: number of candidates kept per class before suppression.
            second_threshold: also drop detections scoring
                ``<= self.conf_thresh``.

        Returns:
            ``(boxes, masks, classes, scores)`` for at most
            ``cfg.max_num_detections`` detections, sorted by descending score.
            These are empty when there are no candidates.
        """
        scores, idx = scores.sort(1, descending=True)

        idx = idx[:, :top_k].contiguous()
        scores = scores[:, :top_k]

        num_classes, num_dets = idx.size()
        if num_dets == 0:
            # No candidates. view(..., -1) and the IoU max below cannot handle
            # zero-length inputs, so return empty results directly.
            empty = idx.new_zeros(0)
            return boxes[empty], masks[empty], empty, scores.new_zeros(0)

        boxes = boxes[idx.view(-1), :].view(num_classes, num_dets, 4)
        masks = masks[idx.view(-1), :].view(num_classes, num_dets, -1)

        # Pairwise IoU within each class (assumes jaccard batches over dim 0).
        iou = jaccard(boxes, boxes)
        # Keep only IoUs with higher-scoring detections (strict upper triangle).
        iou.triu_(diagonal=1)
        iou_max, _ = iou.max(dim=1)

        keep = iou_max <= iou_threshold

        if second_threshold:
            keep &= (scores > self.conf_thresh)

        # Build class ids on the same device as the boxes (not hard-coded CUDA).
        classes = torch.arange(num_classes, device=boxes.device)[:, None].expand_as(keep)
        classes = classes[keep]

        boxes = boxes[keep]
        masks = masks[keep]
        scores = scores[keep]

        # Keep the cfg.max_num_detections highest scores across all classes.
        scores, idx = scores.sort(0, descending=True)
        idx = idx[:cfg.max_num_detections]
        scores = scores[:cfg.max_num_detections]

        classes = classes[idx]
        boxes = boxes[idx]
        masks = masks[idx]

        return boxes, masks, classes, scores
Example #7
0
 def get_feats_xfa(self, x, xb):
     """Compute pooled features ``(xfa_tri, xfa_pos, xfa_neg)`` for query ``x``.

     While running the backbone and RPN (with ``xb`` passed as ``add_box``),
     ``self.rpn.boxes.bb_nums`` / ``bb_thres`` are temporarily set to 64 / 0.5.
     The previous values are restored afterwards, even if an exception is
     raised.

     Returns:
         ``xfa_tri = feats[0]``, ``xfa_pos = feats[2][:, -1]`` and ``xfa_neg``.
         ``xfa_neg`` is ``None`` unless ``self.head_negff`` is set. In that case
         it is a ``[num_batch, nfeat, head_nfeat, roip_size, roip_size]`` tensor
         of pooled RoI features whose IoU with ``xb[i]`` is below ``thres``.
         Short selections are padded by repeating the last index; rows with no
         qualifying RoI stay zero.
     """
     num_batch = x.shape[0]
     thres, nfeat = self.nft_param
     # Remember the RPN settings so they can be restored in ``finally``.
     nbox_num, nbox_thr = self.rpn.boxes.bb_nums, self.rpn.boxes.bb_thres
     # Change the number of candidate negative boxes for this call.
     self.rpn.boxes.bb_nums, self.rpn.boxes.bb_thres = 64, 0.5
     try:
         # Get pos and neg feats from the query image.
         xf = self.backbone(self.normalize_tensor(x))
         # RoI proposals and pooled feats.
         rois, _scores, feats, _ = self.rpn(xf, xf, xb, add_box=xb, pool_xf=True)
         xfa_tri = feats[0]
         xfa_pos = feats[2][:, -1]
         yfa = feats[2][:, :-1]
         # Negative feature mining inside xf.
         if self.head_negff:
             xfa_neg = torch.zeros(num_batch, nfeat, self.head_nfeat,
                                   self.roip_size, self.roip_size,
                                   device=yfa.device)
             for i in range(num_batch):
                 # Choose RoIs (last RoI excluded) whose IoU with xb[i] is below thres.
                 xb_i, roi_i = xb[i], rois[i][:-1, :]
                 iou_i = jaccard(xb_i, roi_i)[0]
                 idx_sel = torch.nonzero(iou_i < thres)[:, 0][:nfeat]
                 if len(idx_sel) == 0:
                     continue
                 # Too few features: pad by repeating the last selected index.
                 if len(idx_sel) < nfeat:
                     pad = idx_sel[-1:].repeat(nfeat - len(idx_sel))
                     idx_sel = torch.cat((idx_sel, pad))
                 xfa_neg[i] = yfa[i, idx_sel]
         else:
             xfa_neg = None
     finally:
         # Restore the default box settings.
         self.rpn.boxes.bb_nums, self.rpn.boxes.bb_thres = nbox_num, nbox_thr
     return (xfa_tri, xfa_pos, xfa_neg)
Example #8
0
    def pre_gt_match_uniform(self,
                             proposal,
                             gt_box,
                             training=True,
                             params=None):
        """Label proposals by their best ground-truth match, keeping only single-peak ones.

        Each proposal takes the label of its best-overlapping ground-truth
        box. The label is then forced to 0 unless BOTH of these hold:
          * the proposal's minimum ``nonoverlap`` value over the ground-truth
            boxes (cast to long) is <= 54, and
          * exactly one entry of the peak list lies inside [start, end] of the
            proposal.

        Args:
            proposal: training: list of per-sample proposal tensors. Otherwise:
                one sample's proposals.
            gt_box: training: list of per-sample ground-truth tensors.
                Otherwise: one sample's tensor. Columns 0-1 are the box and
                column 2 the label; rows with label -1 are dropped.
            training: selects the batched training path (``True``) or the
                single-sample path.
            params: dict with key ``'peak'`` (indexed per sample in training).
                NOTE(review): the default ``None`` would fail at
                ``params['peak']``.

        Returns:
            Training: ``(batch_proposal, batch_label)``. Each proposal set is
            extended with ``self.gt_box_add_noise(gt)`` and the gt boxes, then
            clamped to ``[cfg.left_border, cfg.right_border]``.
            Otherwise: the long label tensor for the single sample.
        """

        batch_proposal = []
        batch_label = []
        if training == True:
            for index in range(len(proposal)):
                peak = params['peak'][index]
                # Drop padding rows (label -1).
                keep = gt_box[index][:, 2] != -1
                this_gt_label = torch.LongTensor(
                    gt_box[index][keep][:, 2].tolist())
                this_proposal = proposal[index]
                this_gt_box = gt_box[index][keep][:, :2].cuda()
                this_noised_gt_box = self.gt_box_add_noise(this_gt_box)

                # Augment proposals with perturbed and exact gt boxes.
                this_proposal = torch.cat(
                    [this_proposal, this_noised_gt_box, this_gt_box], 0)
                this_proposal = torch.clamp(this_proposal,
                                            min=cfg.left_border,
                                            max=cfg.right_border)

                # jaccard returns four values here; union and tt are unused.
                overlap, union, nonoverlap, tt = jaccard(
                    this_gt_box, this_proposal)

                # Per-proposal minimum nonoverlap and maximum overlap over gt boxes.
                nonlap_of_predict, nonlapidx_of_predict = nonoverlap.min(0)
                maxlap_of_predict, maxidx_of_predict = overlap.max(0)
                # maxlap_of_predict = maxlap_of_predict.type(torch.long)
                nonlap_of_predict = nonlap_of_predict.type(torch.long)
                # maxlap_of_ground, maxidx_of_ground = overlap.max(1)
                # this_matches = this_gt_box[maxidx_of_predict]
                # NOTE(review): this_gt_label is a CPU LongTensor indexed by
                # maxidx_of_predict, which may live on CUDA; confirm device handling.
                this_pre_label = this_gt_label[maxidx_of_predict]
                # this_pre_label = torch.zeros(len(maxidx_of_predict))

                # Always taken; presumably a leftover debugging toggle.
                if 1:
                    # mapping[i]: 0 = no peak inside proposal i, 1 = exactly one, 2 = several.
                    mapping = torch.zeros(len(this_pre_label)).cuda()
                    for i in range(len(this_pre_label)):
                        start = this_proposal[i][0].item()
                        end = this_proposal[i][1].item()
                        count = 0
                        for j in range(len(peak)):
                            if start <= peak[j] and end >= peak[j]:
                                count += 1
                        if count == 1:
                            mapping[i] = 1
                        elif count > 1:
                            mapping[i] = 2

                    # 54 is an unexplained magic threshold on nonoverlap — TODO document.
                    total_1 = nonlap_of_predict <= 54
                    total_2 = mapping == 1
                    # Logical AND of the two boolean masks.
                    keep = total_1 * total_2
                    # Proposals failing either check become background (label 0).
                    for i in range(len(keep)):
                        if keep[i] == 1:
                            pass
                            # item = nonlapidx_of_predict[i]
                            # this_pre_label[i] = this_gt_label[item]
                        else:
                            this_pre_label[i] = 0

                this_pre_label = this_pre_label.type(torch.long)
                batch_label.append(this_pre_label.cuda())
                batch_proposal.append(this_proposal)
            return batch_proposal, batch_label
        else:
            # Single-sample path: no augmentation; only labels are returned.
            peak = params['peak']
            keep = gt_box[:, 2] != -1
            this_gt_label = gt_box[keep][:, 2]
            this_proposal = proposal
            this_gt_box = gt_box[keep][:, :2].cuda()

            this_proposal = torch.clamp(this_proposal,
                                        min=cfg.left_border,
                                        max=cfg.right_border)

            overlap, union, nonoverlap, tt = jaccard(this_gt_box,
                                                     this_proposal)
            # Unused below.
            maxlap_of_ground, maxidx_of_ground = overlap.max(1)

            maxlap_of_predict, maxidx_of_predict = overlap.max(0)
            nonlap_of_predict, nonlapidx_of_predict = nonoverlap.min(0)
            # Unused below.
            maxlap_of_predict = maxlap_of_predict.type(torch.long)
            nonlap_of_predict = nonlap_of_predict.type(torch.long)
            # Unused below.
            this_matches = this_gt_box[maxidx_of_predict]
            this_pre_label = this_gt_label[maxidx_of_predict]
            # this_pre_label = torch.zeros(len(maxidx_of_predict)).cuda()

            # Same single-peak check as the training path.
            mapping = torch.zeros(len(this_pre_label)).cuda()
            for i in range(len(this_pre_label)):
                start = this_proposal[i][0].item()
                end = this_proposal[i][1].item()
                count = 0
                for j in range(len(peak)):
                    if start <= peak[j] and end >= peak[j]:
                        count += 1
                if count == 1:
                    mapping[i] = 1
                elif count > 1:
                    mapping[i] = 2

            total_1 = nonlap_of_predict <= 54
            total_2 = mapping == 1
            keep = total_1 * total_2
            for i in range(len(keep)):
                if keep[i] == 1:
                    pass
                    # item = nonlapidx_of_predict[i]
                    # this_pre_label[i] = this_gt_label[item]
                else:
                    this_pre_label[i] = 0
            this_pre_label = this_pre_label.type(torch.long)

        return this_pre_label
Example #9
0
    def predict_gt_match(self, proposal: list, gt_box: list, flag=0):
        """Assign ground-truth labels (and training targets) to proposals.

        ``gt_box`` rows hold the box in columns 0-1 and the label in column 2.
        Rows labelled -1 are ignored. Proposals are clamped to
        ``[cfg.left_border, cfg.right_border]``.

        ``flag == 0`` (training): ``proposal``/``gt_box`` are per-sample lists.
        Each sample's proposals are extended with
        ``self.gt_box_add_noise(gt)`` and the gt boxes, then matched to the
        gt box with the highest overlap. Proposals whose best overlap is
        ``< cfg.roi_neg_thresh`` get label 0.

        Otherwise (evaluation): a single sample. Labels are thresholded
        according to ``cfg.testing_metrics`` (``'80'`` or ``'150ms'``).

        Returns:
            Training: ``(batch_proposal, batch_label, batch_predict_offset,
            batch_weight, batch_pre_weight)``. Evaluation: the label tensor.
        """
        if flag == 0:
            batch_proposal = []
            batch_label = []
            batch_predict_offset = []
            batch_weight = []
            batch_pre_weight = []
            for index in range(len(proposal)):
                keep = gt_box[index][:, 2] != -1
                this_gt_label = gt_box[index][keep][:, 2]
                this_proposal = proposal[index]
                this_gt_box = gt_box[index][keep][:, :2].cuda()
                this_noised_gt_box = self.gt_box_add_noise(this_gt_box)

                this_proposal = torch.cat(
                    [this_proposal, this_noised_gt_box, this_gt_box], 0)
                this_proposal = torch.clamp(this_proposal,
                                            min=cfg.left_border,
                                            max=cfg.right_border)

                overlap = jaccard(this_gt_box, this_proposal)[0]
                # Best gt match per proposal (reduce over the gt axis). This
                # yields 1-D results directly, with no keepdim/squeeze round trip.
                maxlap_of_predict, maxidx_of_predict = overlap.max(0)

                this_matches = this_gt_box[maxidx_of_predict]
                this_pre_label = this_gt_label[maxidx_of_predict]
                this_pre_label[maxlap_of_predict < cfg.roi_neg_thresh] = 0

                # Class-balancing is currently disabled: all counts are 0, so
                # every weight computed below evaluates to 1.
                each_label_num = [0] * 5

                keep = this_pre_label != -1
                batch_weight.append(torch.Tensor([1 for _ in each_label_num]))
                this_proposal = this_proposal[keep]
                this_pre_label = this_pre_label[keep].type(torch.long)
                this_offset = box_to_offset(this_proposal, this_matches[keep])

                label_weight = [(max(each_label_num) + 1) / (n + 1)
                                for n in each_label_num]
                this_pre_weight = torch.Tensor(
                    [label_weight[i] for i in this_pre_label]).view(-1, 1)
                # Same weight for both regression coordinates.
                this_pre_weight = torch.cat([this_pre_weight, this_pre_weight],
                                            1)

                batch_label.append(this_pre_label.cuda())
                batch_proposal.append(this_proposal)
                batch_predict_offset.append(this_offset)
                batch_pre_weight.append(this_pre_weight)

            return batch_proposal, batch_label, batch_predict_offset, batch_weight, batch_pre_weight

        keep = gt_box[:, 2] != -1
        this_gt_label = gt_box[keep][:, 2]
        this_gt_box = gt_box[keep][:, :2].cuda()

        this_proposal = torch.clamp(proposal,
                                    min=cfg.left_border,
                                    max=cfg.right_border)

        # jaccard returns (overlap, union, non_overlap, ...). The other callers
        # in this file unpack four values, so slice to the first three instead
        # of unpacking the whole result.
        overlap, _, nonoverlap = jaccard(this_gt_box, this_proposal)[:3]
        maxlap_of_predict, maxidx_of_predict = overlap.max(0)
        nonlap_of_predict, nonlapidx_of_predict = nonoverlap.min(0)
        this_pre_label = this_gt_label[maxidx_of_predict]

        if cfg.testing_metrics == '80':
            # Overlap-based criterion.
            this_pre_label[maxlap_of_predict <= 0.82] = 0
        elif cfg.testing_metrics == '150ms':
            # Non-overlap-based criterion: match by the smallest non-overlap instead.
            this_pre_label = this_gt_label[nonlapidx_of_predict]
            this_pre_label[nonlap_of_predict >= 54] = 0

        return this_pre_label
Example #10
0
    def RPN_eval(self, data_loader, params):
        """Evaluate the RPN on ``data_loader`` and print window-level metrics.

        Puts ``self.features`` and ``self.RPN`` into eval mode and runs every
        batch through them under ``torch.no_grad()``. Each sample's matching
        statistics go to ``self.first_process(info, ...)``, which fills
        ``info``. Then accuracy / precision / recall / confusion matrix of
        ``info['gt_bin']`` vs ``info['pre_bin']`` are printed. When the accuracy
        beats ``self.max_acc``, ``self.max_acc`` / ``self.max_pre`` /
        ``self.max_recall`` are updated. Both modules are returned to train mode
        afterwards, even if evaluation raises.

        Args:
            data_loader: iterable of batches ``data`` where ``data[0]`` is the
                input, ``data[1]`` the per-sample ground truth (windows in
                columns 0-1, label in the last column), ``data[2]`` the R-peaks
                and ``data[3]`` the per-sample ``num`` values.
            params: unused; kept for interface compatibility.
        """
        self.save_dict = {}

        self.features = self.features.eval()
        self.RPN = self.RPN.eval()
        try:
            tool2 = rpn_tool_d()
            tool2.train_mode = False

            info = {
                'gt': [],
                'pre': [],
                'tp': 0,
                'fp': 0,
                'tn': 0,
                'fn': 0,
                'gt_bin': [],
                'pre_bin': [],
            }

            for data in data_loader:
                y = data[1]
                x = data[0].cuda()
                r_peaks = data[2]
                nums = data[3]
                with torch.no_grad():
                    x1, x2, x3, x4 = self.features(x)
                    predict_confidence, box_predict = self.RPN(x1, x2, x3, x4)
                    proposal, conf, _ = tool2.get_proposal(
                        predict_confidence, box_predict, y, test=True)

                for this_batch in range(len(y)):
                    this_ground = y[this_batch][:, :2].cuda()
                    gt_label = y[this_batch][:, -1].type(torch.uint8).cuda()
                    this_predict = proposal[this_batch]
                    this_conf = conf[this_batch].cpu()
                    # Keep confident proposals only; if none pass, keep them all.
                    keep3 = this_conf >= 0.5
                    if keep3.sum().item() > 0:
                        this_predict = this_predict[keep3]

                    overlaps, _, non_overlap, non_overlap2 = jaccard(this_ground, this_predict)

                    maxlap_of_ground, maxidx_of_ground = overlaps.max(1)
                    minlap_of_pre, minidx_of_pre = non_overlap.min(0)
                    minlap_of_ground, minidx_of_ground = non_overlap2.min(1)

                    # Per-sample inputs for first_process. A separate name keeps
                    # the ``params`` argument from being shadowed.
                    window_params = {
                        'ground_window': this_ground.cpu(),
                        'maxlap_of_ground': maxlap_of_ground,
                        'maxidx_of_ground': maxidx_of_ground,
                        'minlap_of_pre': minlap_of_pre,
                        'minidx_of_pre': minidx_of_pre,
                        'minlap_of_ground': minlap_of_ground,
                        'minidx_of_ground': minidx_of_ground,
                        'num': nums[this_batch],
                        'peak': r_peaks[this_batch],
                        'pre_window': this_predict.cpu(),
                        'save': self.save_dict,
                        'data': x[this_batch].cpu().numpy(),
                        'gt_label': gt_label,
                    }
                    self.first_process(info, window_params)

            gt_bin = info['gt_bin']
            pre_bin = info['pre_bin']

            # Compute each metric once and reuse it for printing and best-so-far tracking.
            acc = accuracy_score(gt_bin, pre_bin)
            precision = precision_score(gt_bin, pre_bin)
            recall = recall_score(gt_bin, pre_bin)

            print("acc:{}".format(acc))
            print("precision:{}".format(precision))
            print("recall:{}".format(recall))
            print("confusion:{}".format(confusion_matrix(gt_bin, pre_bin)))
            if acc > self.max_acc:
                self.max_acc = acc
                self.max_pre = precision
                self.max_recall = recall
            print("acc:{}".format(self.max_acc))
            print("precision:{}".format(self.max_pre))
            print("recall:{}".format(self.max_recall))
        finally:
            # Always return the modules to training mode.
            self.RPN = self.RPN.train()
            self.features = self.features.train()