Example #1
0
    def __init__(self, backbone='res50', hyps=None):
        """Build a single-stage rotated-box RetinaNet.

        Args:
            backbone: backbone identifier forwarded to ``init_backbone``
                (e.g. 'res50').
            hyps: hyper-parameter mapping; ``hyps['num_classes']`` is read and
                one extra class slot is added on top of it.
        """
        super(RetinaNet, self).__init__()
        # One more output slot than the configured number of classes.
        self.num_classes = int(hyps['num_classes']) + 1
        self.anchor_generator = Anchors(ratios=np.array([0.5, 1, 2]))
        self.num_anchors = self.anchor_generator.num_anchors
        # Sets self.backbone and self.fpn_in_channels, which the FPN needs.
        self.init_backbone(backbone)

        width = 256  # channel width shared by the FPN outputs and both heads
        self.fpn = FPN(in_channels_list=self.fpn_in_channels,
                       out_channels=width,
                       top_blocks=LastLevelP6P7(self.fpn_in_channels[-1], width),
                       use_asff=False)
        head_cfg = dict(in_channels=width,
                        feat_channels=width,
                        num_stacked=4,
                        num_anchors=self.num_anchors)
        self.cls_head = CLSHead(num_classes=self.num_classes, **head_cfg)
        self.reg_head = REGHead(num_regress=5, **head_cfg)  # xywha
        self.loss = IntegratedLoss(func='smooth')
        self.box_coder = BoxCoder()
 def __init__(self, backbone='res50', num_classes=2, num_refining=1):
     """Build the STELA detector.

     Args:
         backbone: backbone identifier forwarded to ``init_backbone``.
         num_classes: number of classification-head outputs (passed through
             unchanged).
         num_refining: number of extra refinement regression heads; when it
             is 0, neither ``ref_heads`` nor ``loss_ref`` is created.
     """
     super(STELA, self).__init__()
     self.anchor_generator = Anchors()
     self.num_anchors = self.anchor_generator.num_anchors
     # Sets self.backbone and self.fpn_in_channels.
     self.init_backbone(backbone)
     width = 256  # channel width shared by the FPN and every head
     self.fpn = FPN(in_channels_list=self.fpn_in_channels,
                    out_channels=width,
                    top_blocks=LastLevelP6P7(self.fpn_in_channels[-1], width))
     head_cfg = dict(in_channels=width,
                     feat_channels=width,
                     num_stacked=1,
                     num_anchors=self.num_anchors)
     self.cls_head = CLSHead(num_classes=num_classes, **head_cfg)
     self.reg_head = REGHead(num_regress=5, **head_cfg)
     self.num_refining = num_refining
     if self.num_refining > 0:
         refiners = [REGHead(num_regress=5, **head_cfg)
                     for _ in range(self.num_refining)]
         self.ref_heads = nn.ModuleList(refiners)
         self.loss_ref = RegressLoss(func='smooth')
     self.loss_cls = FocalLoss()
     self.loss_reg = RegressLoss(func='smooth')
     self.box_coder = BoxCoder()
Example #3
0
 def __init__(self, func='smooth'):
     """Select the element-wise regression criterion.

     Args:
         func: 'smooth' -> ``smooth_l1_loss``, 'mse' -> ``F.mse_loss``,
             'balanced' -> ``balanced_l1_loss``.

     Raises:
         NotImplementedError: if ``func`` is none of the names above.
     """
     super(RegressLoss, self).__init__()
     self.box_coder = BoxCoder()
     if func == 'smooth':
         self.criteron = smooth_l1_loss
     elif func == 'mse':
         self.criteron = F.mse_loss
     elif func == 'balanced':
         self.criteron = balanced_l1_loss
     else:
         # Without this branch `self.criteron` would be missing and the
         # failure would only surface later as an AttributeError.
         raise NotImplementedError(
             'unsupported regression loss: {!r}'.format(func))
Example #4
0
class RegressLoss(nn.Module):
    """Regression loss on anchors matched to ground-truth boxes by rotated IoU.

    An anchor is positive when its rotated IoU with some ground-truth box is
    at least ``iou_thres``. In addition, for every ground-truth box whose best
    anchor IoU is below ``iou_thres``, that best anchor is forced positive
    (max-IoU assignment).
    """

    def __init__(self, func='smooth'):
        """Select the element-wise regression criterion.

        Args:
            func: 'smooth' -> ``smooth_l1_loss``, 'mse' -> ``F.mse_loss``,
                'balanced' -> ``balanced_l1_loss``.

        Raises:
            NotImplementedError: if ``func`` is none of the names above.
        """
        super(RegressLoss, self).__init__()
        # Resolve the criterion first so an unsupported name fails fast,
        # before any other state is built.
        if func == 'smooth':
            criteron = smooth_l1_loss
        elif func == 'mse':
            criteron = F.mse_loss
        elif func == 'balanced':
            criteron = balanced_l1_loss
        else:
            raise NotImplementedError(
                'unsupported regression loss: {!r}'.format(func))
        self.box_coder = BoxCoder()
        self.criteron = criteron

    def forward(self, regressions, anchors, annotations, iou_thres=0.5):
        """Average the per-image regression loss over the batch.

        Args:
            regressions: predicted offsets, indexed ``[image, anchor, :]``.
            anchors: anchors, indexed ``[image, anchor, :]``.
            annotations: ground-truth boxes, indexed ``[image, box, :]``; the
                last column is a label and rows whose label is -1 are
                ignored as padding.
            iou_thres: rotated IoU at or above which an anchor is positive.

        Returns:
            Tensor of shape (1,): mean of the per-image losses. Images with no
            annotations or no positive anchors contribute a constant 0.
        """
        # Keep every tensor on the input's device instead of forcing CUDA.
        device = regressions.device
        losses = []
        for j in range(regressions.shape[0]):
            regression = regressions[j, :, :]
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, -1] != -1]
            if bbox_annotation.shape[0] == 0:
                losses.append(torch.zeros((), dtype=torch.float32, device=device))
                continue
            # Axis-aligned overlaps of the min-area squares; passed to
            # rbox_overlaps alongside the boxes (presumably to skip pairs that
            # cannot overlap -- confirm against rbox_overlaps).
            indicator = bbox_overlaps(
                min_area_square(anchors[j, :, :]),
                min_area_square(bbox_annotation[:, :-1])
            )
            overlaps = rbox_overlaps(
                anchors[j, :, :].cpu().numpy(),
                bbox_annotation[:, :-1].cpu().numpy(),
                indicator.cpu().numpy(),
                thresh=1e-1
            )
            if not torch.is_tensor(overlaps):
                overlaps = torch.from_numpy(overlaps).to(device)

            iou_max, iou_argmax = torch.max(overlaps, dim=1)
            positive_indices = torch.ge(iou_max, iou_thres)
            # MaxIoU assigner: a box with no anchor above the threshold still
            # gets its best-overlapping anchor as a positive.
            max_gt, argmax_gt = overlaps.max(0)
            weak_gt = max_gt < iou_thres
            if weak_gt.any():
                positive_indices[argmax_gt[weak_gt]] = 1

            if positive_indices.sum() > 0:
                assigned_annotations = bbox_annotation[iou_argmax, :]
                all_rois = anchors[j, positive_indices, :]
                gt_boxes = assigned_annotations[positive_indices, :]
                targets = self.box_coder.encode(all_rois, gt_boxes)
                losses.append(
                    self.criteron(regression[positive_indices, :], targets))
            else:
                losses.append(torch.zeros((), dtype=torch.float32, device=device))
        return torch.stack(losses).mean(dim=0, keepdim=True)
Example #5
0
 def __init__(self, alpha=0.25, gamma=2.0, func='smooth'):
     """Configure the focal classification term and the regression criterion.

     Args:
         alpha: focal-loss weight applied to positive targets (negatives get
             ``1 - alpha``).
         gamma: focal-loss focusing exponent.
         func: 'smooth' -> ``smooth_l1_loss``, 'mse' -> ``F.mse_loss``,
             'balanced' -> ``balanced_l1_loss``.

     Raises:
         NotImplementedError: if ``func`` is none of the names above.
     """
     super(IntegratedLoss, self).__init__()
     self.alpha = alpha
     self.gamma = gamma
     self.box_coder = BoxCoder()
     if func == 'smooth':
         self.criteron = smooth_l1_loss
     elif func == 'mse':
         self.criteron = F.mse_loss
     elif func == 'balanced':
         self.criteron = balanced_l1_loss
     else:
         # Without this branch `self.criteron` would be missing and the
         # failure would only surface later as an AttributeError.
         raise NotImplementedError(
             'unsupported regression loss: {!r}'.format(func))
Example #6
0
 def __init__(self, backbone='res50', hyps=None):
     """Build a RetinaNet with optional cascaded refinement regressors.

     Args:
         backbone: backbone identifier forwarded to ``init_backbone``.
         hyps: hyper-parameter mapping; ``hyps['num_refining']`` (number of
             refinement heads) and ``hyps['num_classes']`` are read.
     """
     super(RetinaNet, self).__init__()
     self.num_refining = int(hyps['num_refining'])
     # One more output slot than the configured number of classes.
     self.num_classes = int(hyps['num_classes']) + 1
     self.anchor_generator = Anchors(ratios=np.array([0.5, 1, 2]))
     self.num_anchors = self.anchor_generator.num_anchors
     # Sets self.backbone and self.fpn_in_channels.
     self.init_backbone(backbone)

     width = 256  # channel width shared by the FPN and every head
     self.fpn = FPN(
         in_channels_list=self.fpn_in_channels,
         out_channels=width,
         top_blocks=LastLevelP6P7(self.fpn_in_channels[-1], width),
         use_asff=False,
     )
     shared = dict(in_channels=width,
                   feat_channels=width,
                   num_anchors=self.num_anchors)
     self.cls_head = CLSHead(num_stacked=4, num_classes=self.num_classes, **shared)
     self.reg_head = REGHead(num_stacked=4, num_regress=5, **shared)  # xywha
     if self.num_refining > 0:
         # Refinement heads are shallower (2 stacked convs) than the main heads.
         self.ref_heads = nn.ModuleList(
             REGHead(num_stacked=2, num_regress=5, **shared)
             for _ in range(self.num_refining)
         )
         self.loss_ref = RegressLoss(func='smooth')
     self.loss = IntegratedLoss(func='smooth')
     self.box_coder = BoxCoder()
Example #7
0
def run(args):
    """Train box2pix and evaluate it on the train/val loaders every epoch.

    Builds the data loaders, model, task losses and Adam optimizer (optionally
    restoring them from ``args.resume``). Wires up ignite handlers for
    running-average metrics, a progress bar, TensorBoard logging, per-epoch
    checkpointing and train/validation evaluation, then trains for
    ``args.epochs`` epochs.

    Args:
        args: parsed command-line namespace. Attributes read here: ``dir``,
            ``batch_size``, ``num_workers``, ``seed``, ``log_dir``, ``lr``,
            ``resume``, ``output_dir``, ``log_interval`` and ``epochs``.
            ``start_epoch`` is written when resuming.
    """
    train_loader, val_loader = get_data_loaders(args.dir, args.batch_size,
                                                args.num_workers)

    if args.seed is not None:
        torch.manual_seed(args.seed)

    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    num_classes = CityscapesDataset.num_instance_classes() + 1
    model = models.box2pix(num_classes=num_classes)
    model.init_from_googlenet()

    writer = create_summary_writer(model, train_loader, args.log_dir)

    if torch.cuda.device_count() > 1:
        print("Using %d GPU(s)" % torch.cuda.device_count())
        model = nn.DataParallel(model)

    model = model.to(device)

    semantics_criterion = nn.CrossEntropyLoss(ignore_index=255)
    offsets_criterion = nn.MSELoss()
    box_criterion = BoxLoss(num_classes, gamma=2)
    # The multitask loss has its own parameters; they are optimized together
    # with the model (second param group below).
    multitask_criterion = MultiTaskLoss().to(device)

    box_coder = BoxCoder()
    optimizer = optim.Adam([{
        'params': model.parameters(),
        'weight_decay': 5e-4
    }, {
        'params': multitask_criterion.parameters()
    }],
                           lr=args.lr)

    if args.resume:
        if os.path.isfile(args.resume):
            print("Loading checkpoint '{}'".format(args.resume))
            # map_location lets a checkpoint written on a GPU be restored on a
            # CPU-only machine (and vice versa).
            checkpoint = torch.load(args.resume, map_location=device)
            args.start_epoch = checkpoint['epoch']
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            multitask_criterion.load_state_dict(checkpoint['multitask'])
            print("Loaded checkpoint '{}' (Epoch {})".format(
                args.resume, checkpoint['epoch']))
        else:
            print("No checkpoint found at '{}'".format(args.resume))

    def _prepare_batch(batch, non_blocking=True):
        """Move an (images, instance, boxes, labels) batch to ``device``."""
        x, instance, boxes, labels = batch

        return (convert_tensor(x, device=device, non_blocking=non_blocking),
                convert_tensor(instance,
                               device=device,
                               non_blocking=non_blocking),
                convert_tensor(boxes, device=device,
                               non_blocking=non_blocking),
                convert_tensor(labels,
                               device=device,
                               non_blocking=non_blocking))

    def _update(engine, batch):
        """Run one optimization step and return the scalar losses."""
        model.train()
        optimizer.zero_grad()
        x, instance, boxes, labels = _prepare_batch(batch)
        boxes, labels = box_coder.encode(boxes, labels)

        loc_preds, conf_preds, semantics_pred, offsets_pred = model(x)

        semantics_loss = semantics_criterion(semantics_pred, instance)
        offsets_loss = offsets_criterion(offsets_pred, instance)
        box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                            labels)

        loss = multitask_criterion(semantics_loss, offsets_loss, box_loss,
                                   conf_loss)

        loss.backward()
        optimizer.step()

        return {
            'loss': loss.item(),
            'loss_semantics': semantics_loss.item(),
            'loss_offsets': offsets_loss.item(),
            'loss_ssdbox': box_loss.item(),
            'loss_ssdclass': conf_loss.item()
        }

    trainer = Engine(_update)

    checkpoint_handler = ModelCheckpoint(args.output_dir,
                                         'checkpoint',
                                         save_interval=1,
                                         n_saved=10,
                                         require_empty=False,
                                         create_dir=True,
                                         save_as_state_dict=False)
    timer = Timer(average=True)

    # attach running average metrics
    train_metrics = [
        'loss', 'loss_semantics', 'loss_offsets', 'loss_ssdbox',
        'loss_ssdclass'
    ]
    for m in train_metrics:
        # partial binds `m` now, avoiding the late-binding closure pitfall.
        transform = partial(lambda x, metric: x[metric], metric=m)
        RunningAverage(output_transform=transform).attach(trainer, m)

    # attach progress bar
    pbar = ProgressBar(persist=True)
    pbar.attach(trainer, metric_names=train_metrics)

    def _save_checkpoint(engine):
        """Save model/optimizer/multitask state and the current epoch."""
        # The payload must be built when the event fires. A dict built once at
        # setup time would read `trainer.state` before run() (where it may not
        # be populated yet) and would freeze the epoch. It would also freeze
        # the optimizer state, which Adam only creates on its first step().
        checkpoint = {
            'model': model.state_dict(),
            'epoch': engine.state.epoch,
            'optimizer': optimizer.state_dict(),
            'multitask': multitask_criterion.state_dict()
        }
        checkpoint_handler(engine, {'checkpoint': checkpoint})

    trainer.add_event_handler(Events.EPOCH_COMPLETED, _save_checkpoint)

    timer.attach(trainer,
                 start=Events.EPOCH_STARTED,
                 resume=Events.ITERATION_STARTED,
                 pause=Events.ITERATION_COMPLETED,
                 step=Events.ITERATION_COMPLETED)

    def _inference(engine, batch):
        """Evaluate one batch: losses, decoded detections and pixel outputs."""
        model.eval()
        with torch.no_grad():
            x, instance, boxes, labels = _prepare_batch(batch)
            loc_preds, conf_preds, semantics, offsets_pred = model(x)
            boxes_preds, labels_preds, scores_preds = box_coder.decode(
                loc_preds, F.softmax(conf_preds, dim=1), score_thresh=0.01)

            semantics_loss = semantics_criterion(semantics, instance)
            offsets_loss = offsets_criterion(offsets_pred, instance)
            box_loss, conf_loss = box_criterion(loc_preds, boxes, conf_preds,
                                                labels)

            semantics_pred = semantics.argmax(dim=1)
            instances = helper.assign_pix2box(semantics_pred, offsets_pred,
                                              boxes_preds, labels_preds)

        return {
            'loss': (semantics_loss, offsets_loss, {
                'box_loss': box_loss,
                'conf_loss': conf_loss
            }),
            'objects':
            (boxes_preds, labels_preds, scores_preds, boxes, labels),
            'semantics':
            semantics_pred,
            'instances':
            instances
        }

    train_evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(train_evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             train_evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              train_evaluator, 'semantics')

    evaluator = Engine(_inference)
    Loss(multitask_criterion,
         output_transform=lambda x: x['loss']).attach(evaluator, 'loss')
    MeanAveragePrecision(num_classes,
                         output_transform=lambda x: x['objects']).attach(
                             evaluator, 'objects')
    IntersectionOverUnion(num_classes,
                          output_transform=lambda x: x['semantics']).attach(
                              evaluator, 'semantics')

    @trainer.on(Events.STARTED)
    def initialize(engine):
        if args.resume:
            engine.state.epoch = args.start_epoch

    @trainer.on(Events.EPOCH_COMPLETED)
    def print_times(engine):
        pbar.log_message(
            "Epoch [{}/{}] done. Time per batch: {:.3f}[s]".format(
                engine.state.epoch, engine.state.max_epochs, timer.value()))
        timer.reset()

    @trainer.on(Events.ITERATION_COMPLETED)
    def log_training_loss(engine):
        # Iteration index within the current epoch (1-based).
        iteration = (engine.state.iteration - 1) % len(train_loader) + 1
        if iteration % args.log_interval == 0:
            writer.add_scalar("training/loss", engine.state.output['loss'],
                              engine.state.iteration)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_training_results(engine):
        train_evaluator.run(train_loader)
        metrics = train_evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        # Arguments follow the placeholder order: epoch, max_epochs, loss,
        # mAP, IoU. The epoch counters come from the trainer (`engine`).
        pbar.log_message(
            'Training results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("train-val/loss", loss, engine.state.epoch)
        writer.add_scalar("train-val/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("train-val/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EPOCH_COMPLETED)
    def log_validation_results(engine):
        evaluator.run(val_loader)
        metrics = evaluator.state.metrics
        loss = metrics['loss']
        mean_ap = metrics['objects']
        iou = metrics['semantics']

        pbar.log_message(
            'Validation results - Epoch: [{}/{}]: Loss: {:.4f}, mAP(50%): {:.1f}, IoU: {:.1f}'
            .format(engine.state.epoch, engine.state.max_epochs, loss,
                    mean_ap, iou * 100.0))

        writer.add_scalar("validation/loss", loss, engine.state.epoch)
        writer.add_scalar("validation/mAP", mean_ap, engine.state.epoch)
        writer.add_scalar("validation/IoU", iou, engine.state.epoch)

    @trainer.on(Events.EXCEPTION_RAISED)
    def handle_exception(engine, e):
        # Ctrl-C after the first iteration: stop cleanly and keep the model.
        if isinstance(e, KeyboardInterrupt) and (engine.state.iteration > 1):
            engine.terminate()
            warnings.warn("KeyboardInterrupt caught. Exiting gracefully.")

            checkpoint_handler(engine, {'model_exception': model})
        else:
            raise e

    @trainer.on(Events.COMPLETED)
    def save_final_model(engine):
        checkpoint_handler(engine, {'final': model})

    try:
        trainer.run(train_loader, max_epochs=args.epochs)
    finally:
        # Flush and close the TensorBoard writer even if training raises.
        writer.close()
Example #8
0
class RetinaNet(nn.Module):
    """RetinaNet-style detector regressing 5 values per anchor ('xywht').

    Pipeline: ResNet/ResNeXt backbone -> FPN (with P6/P7 top blocks) ->
    per-level classification and regression heads, concatenated along the
    anchor axis. In training mode ``forward`` returns a dict of losses; in
    eval mode it returns NMS-filtered detections for the first image of the
    batch.
    """

    def __init__(self, backbone='res50', hyps=None):
        """Build the detector.

        Args:
            backbone: one of 'res34', 'res50', 'res101', 'res152', 'resnext50'.
            hyps: hyper-parameter mapping; ``hyps['num_classes']`` is read and
                one extra class slot is added on top of it.
        """
        super(RetinaNet, self).__init__()
        self.num_classes = int(hyps['num_classes']) + 1
        self.anchor_generator = Anchors(ratios=np.array([0.5, 1, 2]))
        self.num_anchors = self.anchor_generator.num_anchors
        # Sets self.backbone and self.fpn_in_channels, used by the FPN below.
        self.init_backbone(backbone)

        self.fpn = FPN(in_channels_list=self.fpn_in_channels,
                       out_channels=256,
                       top_blocks=LastLevelP6P7(self.fpn_in_channels[-1], 256),
                       use_asff=False)
        self.cls_head = CLSHead(in_channels=256,
                                feat_channels=256,
                                num_stacked=4,
                                num_anchors=self.num_anchors,
                                num_classes=self.num_classes)
        self.reg_head = REGHead(
            in_channels=256,
            feat_channels=256,
            num_stacked=4,
            num_anchors=self.num_anchors,
            num_regress=5  # xywha
        )
        self.loss = IntegratedLoss(func='smooth')
        self.box_coder = BoxCoder()

    def init_backbone(self, backbone):
        """Load an ImageNet-pretrained torchvision backbone.

        Sets ``self.backbone`` (with ``avgpool`` and ``fc`` removed) and
        ``self.fpn_in_channels`` (channel counts of the C3, C4, C5 outputs).

        Raises:
            NotImplementedError: for an unknown ``backbone`` name.
        """
        if backbone == 'res34':
            self.backbone = models.resnet34(pretrained=True)
            self.fpn_in_channels = [128, 256, 512]
        elif backbone == 'res50':
            self.backbone = models.resnet50(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        elif backbone == 'res101':
            self.backbone = models.resnet101(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        elif backbone == 'res152':
            self.backbone = models.resnet152(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        elif backbone == 'resnext50':
            self.backbone = models.resnext50_32x4d(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        else:
            raise NotImplementedError
        # The classification tail is unused; only the conv stages feed the FPN.
        del self.backbone.avgpool
        del self.backbone.fc

    def ims_2_features(self, ims):
        """Run the backbone and return the C3, C4, C5 feature maps."""
        c1 = self.backbone.relu(self.backbone.bn1(self.backbone.conv1(ims)))
        c2 = self.backbone.layer1(self.backbone.maxpool(c1))
        c3 = self.backbone.layer2(c2)
        c4 = self.backbone.layer3(c3)
        c5 = self.backbone.layer4(c4)
        # c_i shape: bs, C, H, W
        return [c3, c4, c5]

    def forward(self, ims, gt_boxes=None, test_conf=None, process=None):
        """Compute training losses or eval-mode detections.

        Args:
            ims: input image batch.
            gt_boxes: ground-truth annotations (training only).
            test_conf: optional score threshold override (eval only).
            process: training progress used by ``calc_mining_param``;
                required in training mode.

        Returns:
            Training: ``{'loss_cls': ..., 'loss_reg': ...}``.
            Eval: ``[nms_scores, nms_class, output_boxes]`` from ``decoder``.
        """
        anchors = self.anchor_generator(ims)  # (bs, num_all_anchors, 5)
        features = self.fpn(self.ims_2_features(ims))
        cls_score = torch.cat([self.cls_head(f) for f in features], dim=1)
        bbox_pred = torch.cat([self.reg_head(f) for f in features], dim=1)

        if not self.training:
            # decoder() decodes bbox_pred itself, so skip the training-only decode.
            return self.decoder(ims,
                                anchors,
                                cls_score,
                                bbox_pred,
                                test_conf=test_conf)

        # Decoded predictions, detached so no gradient flows through them,
        # are handed to the loss next to the original anchors.
        bboxes = self.box_coder.decode(anchors, bbox_pred,
                                       mode='xywht').detach()
        bf_weight = self.calc_mining_param(process, 0.3)
        losses = dict()
        losses['loss_cls'], losses['loss_reg'] = self.loss(
            cls_score, bbox_pred, anchors, bboxes, gt_boxes,
            md_thres=0.6,
            mining_param=(bf_weight, 1 - bf_weight, 5))
        return losses

    def decoder(self,
                ims,
                anchors,
                cls_score,
                bbox_pred,
                thresh=0.6,
                nms_thresh=0.2,
                test_conf=None):
        """Turn raw head outputs into NMS-filtered detections for image 0.

        Args:
            ims: input images (passed to ``clip_boxes``).
            anchors: anchors, indexed ``[image, anchor, :]``.
            cls_score: class scores, indexed ``[image, anchor, class]``.
            bbox_pred: regression outputs, indexed ``[image, anchor, :]``.
            thresh: minimum max-class score for a box to be kept.
            nms_thresh: threshold passed to ``nms``.
            test_conf: if not None, overrides ``thresh``.

        Returns:
            ``[nms_scores, nms_class, output_boxes]``, where ``output_boxes``
            concatenates each kept decoded box with its anchor along dim 1.
            If no score reaches the threshold, returns
            ``[zeros(1), zeros(1), zeros(1, 5)]``.
        """
        if test_conf is not None:
            thresh = test_conf
        bboxes = self.box_coder.decode(anchors, bbox_pred, mode='xywht')
        bboxes = clip_boxes(bboxes, ims)
        scores = torch.max(cls_score, dim=2, keepdim=True)[0]
        # NOTE(review): only batch index 0 is considered below.
        keep = (scores >= thresh)[0, :, 0]
        if keep.sum() == 0:
            # NOTE(review): these placeholders are created on the default
            # device and have 5 box columns, unlike the non-empty result which
            # concatenates box and anchor columns -- confirm callers expect this.
            return [torch.zeros(1), torch.zeros(1), torch.zeros(1, 5)]
        scores = scores[:, keep, :]
        anchors = anchors[:, keep, :]
        cls_score = cls_score[:, keep, :]
        bboxes = bboxes[:, keep, :]
        # NMS
        anchors_nms_idx = nms(
            torch.cat([bboxes, scores], dim=2)[0, :, :], nms_thresh)
        nms_scores, nms_class = cls_score[0, anchors_nms_idx, :].max(dim=1)
        output_boxes = torch.cat(
            [bboxes[0, anchors_nms_idx, :], anchors[0, anchors_nms_idx, :]],
            dim=1)
        return [nms_scores, nms_class, output_boxes]

    def freeze_bn(self):
        """Put every BatchNorm2d layer in eval mode (frozen running stats)."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()

    def calc_mining_param(self, process, alpha):
        """Weight used as the first entry of the loss's ``mining_param``.

        Returns 1.0 while ``process < 0.1`` and ``alpha`` once
        ``process > 0.3``. In between it interpolates linearly; the line
        equals 1.0 at 0.1 and ``alpha`` at 0.3, so the schedule is continuous.

        Raises:
            ValueError: if ``process`` is None (it must be supplied in training).
        """
        if process is None:
            raise ValueError(
                '`process` (training progress) is required in training mode')
        if process < 0.1:
            bf_weight = 1.0
        elif process > 0.3:
            bf_weight = alpha
        else:
            bf_weight = 5 * (alpha - 1) * process + 1.5 - 0.5 * alpha
        return bf_weight
Example #9
0
class IntegratedLoss(nn.Module):
    """Focal classification loss plus regression loss on IoU-matched anchors.

    Per image, anchors are labelled from their rotated IoU with ground truth:
    - positive: IoU >= ``iou_thres``, plus the best anchor of every box whose
      best IoU is below ``iou_thres``;
    - negative: IoU < ``iou_thres - 0.1`` and not positive;
    - ignored by the classification loss: everything else.
    """

    def __init__(self, alpha=0.25, gamma=2.0, func='smooth'):
        """Configure the focal term and the regression criterion.

        Args:
            alpha: focal weight for positive targets (negatives get
                ``1 - alpha``).
            gamma: focal focusing exponent.
            func: 'smooth' -> ``smooth_l1_loss``, 'mse' -> ``F.mse_loss``,
                'balanced' -> ``balanced_l1_loss``.

        Raises:
            NotImplementedError: if ``func`` is none of the names above.
        """
        super(IntegratedLoss, self).__init__()
        # Resolve the criterion first so an unsupported name fails fast
        # instead of leaving `self.criteron` unset.
        if func == 'smooth':
            criteron = smooth_l1_loss
        elif func == 'mse':
            criteron = F.mse_loss
        elif func == 'balanced':
            criteron = balanced_l1_loss
        else:
            raise NotImplementedError(
                'unsupported regression loss: {!r}'.format(func))
        self.alpha = alpha
        self.gamma = gamma
        self.box_coder = BoxCoder()
        self.criteron = criteron

    def forward(self, classifications, regressions, anchors, annotations, iou_thres=0.5):
        """Compute batch-averaged classification and regression losses.

        Args:
            classifications: class probabilities, ``[image, anchor, class]``.
            regressions: predicted offsets, ``[image, anchor, :]``.
            anchors: anchors, ``[image, anchor, :]``.
            annotations: ground truth, ``[image, box, :]``; the last column is
                the class label and rows labelled -1 are padding.
            iou_thres: IoU at or above which an anchor is positive.

        Returns:
            ``(loss_cls, loss_reg)``, each a tensor of shape (1,).

        Raises:
            FloatingPointError: if a per-image regression loss is not finite.
        """
        # Keep all tensors on the input's device instead of forcing CUDA.
        device = classifications.device
        cls_losses = []
        reg_losses = []
        for j in range(classifications.shape[0]):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, -1] != -1]
            if bbox_annotation.shape[0] == 0:
                cls_losses.append(torch.zeros((), dtype=torch.float32, device=device))
                reg_losses.append(torch.zeros((), dtype=torch.float32, device=device))
                continue
            # Keep probabilities away from 0/1 so the logs below stay finite.
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)
            indicator = bbox_overlaps(
                min_area_square(anchors[j, :, :]),
                min_area_square(bbox_annotation[:, :-1])
            )
            ious = rbox_overlaps(
                anchors[j, :, :].cpu().numpy(),
                bbox_annotation[:, :-1].cpu().numpy(),
                indicator.cpu().numpy(),
                thresh=1e-1
            )
            if not torch.is_tensor(ious):
                ious = torch.from_numpy(ious).to(device)

            iou_max, iou_argmax = torch.max(ious, dim=1)
            positive_indices = torch.ge(iou_max, iou_thres)

            # Max-IoU assigner: every box keeps at least its best anchor.
            max_gt, argmax_gt = ious.max(0)
            weak_gt = max_gt < iou_thres
            if weak_gt.any():
                positive_indices[argmax_gt[weak_gt]] = 1

            # cls loss: targets -1 = ignore, 0 = negative, 1 = positive class.
            cls_targets = -torch.ones(classification.shape, device=device)
            cls_targets[torch.lt(iou_max, iou_thres - 0.1), :] = 0
            num_positive_anchors = positive_indices.sum()
            assigned_annotations = bbox_annotation[iou_argmax, :]
            cls_targets[positive_indices, :] = 0
            cls_targets[positive_indices, assigned_annotations[positive_indices, -1].long()] = 1
            alpha_factor = torch.ones(cls_targets.shape, device=device) * self.alpha
            alpha_factor = torch.where(torch.eq(cls_targets, 1.), alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(torch.eq(cls_targets, 1.), 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
            bin_cross_entropy = -(cls_targets * torch.log(classification + 1e-6)
                                  + (1.0 - cls_targets) * torch.log(1.0 - classification + 1e-6))
            cls_loss = focal_weight * bin_cross_entropy
            cls_loss = torch.where(torch.ne(cls_targets, -1.0), cls_loss,
                                   torch.zeros(cls_loss.shape, device=device))
            cls_losses.append(cls_loss.sum() / torch.clamp(num_positive_anchors.float(), min=1.0))

            # reg loss
            if num_positive_anchors > 0:
                all_rois = anchors[j, positive_indices, :]
                gt_boxes = assigned_annotations[positive_indices, :]
                reg_targets = self.box_coder.encode(all_rois, gt_boxes)
                reg_loss = self.criteron(regression[positive_indices, :], reg_targets)
                # Previously an interactive ipdb breakpoint; fail loudly
                # instead of hanging (or crashing when ipdb is absent).
                if not torch.isfinite(reg_loss).all():
                    raise FloatingPointError(
                        'non-finite regression loss for image {}'.format(j))
                reg_losses.append(reg_loss)
            else:
                reg_losses.append(torch.zeros((), dtype=torch.float32, device=device))
        loss_cls = torch.stack(cls_losses).mean(dim=0, keepdim=True)
        loss_reg = torch.stack(reg_losses).mean(dim=0, keepdim=True)
        return loss_cls, loss_reg
Example #10
0
class RetinaNet(nn.Module):
    """RetinaNet-style detector regressing 5 values per anchor ('xywht').

    Pipeline: ResNet/ResNeXt backbone -> FPN (with P6/P7 top blocks) ->
    per-level classification and regression heads, concatenated along the
    anchor axis. In training mode ``forward`` returns a dict of losses; in
    eval mode it returns NMS-filtered detections for the first image of the
    batch.
    """

    def __init__(self, backbone='res50', hyps=None):
        """Build the detector.

        Args:
            backbone: one of 'res34', 'res50', 'res101', 'res152', 'resnext50'.
            hyps: hyper-parameter mapping; ``hyps['num_classes']`` is read and
                one extra class slot is added on top of it.
        """
        super(RetinaNet, self).__init__()
        # num_classes depends on the classification level: hrsc_dataset.py
        # defines three levels, so the `level` used for training must be
        # changed to match.
        self.num_classes = int(hyps['num_classes']) + 1
        self.anchor_generator = Anchors(ratios=np.array([0.5, 1, 2]))
        self.num_anchors = self.anchor_generator.num_anchors
        # Sets self.backbone and self.fpn_in_channels, used by the FPN below.
        self.init_backbone(backbone)

        self.fpn = FPN(in_channels_list=self.fpn_in_channels,
                       out_channels=256,
                       top_blocks=LastLevelP6P7(self.fpn_in_channels[-1], 256),
                       use_asff=False)
        self.cls_head = CLSHead(in_channels=256,
                                feat_channels=256,
                                num_stacked=4,
                                num_anchors=self.num_anchors,
                                num_classes=self.num_classes)
        self.reg_head = REGHead(
            in_channels=256,
            feat_channels=256,
            num_stacked=4,
            num_anchors=self.num_anchors,
            num_regress=5  # xywha
        )
        self.loss = IntegratedLoss(func='smooth')  # computes the losses
        self.box_coder = BoxCoder()  # encodes/decodes regression values

    def init_backbone(self, backbone):
        """Load an ImageNet-pretrained torchvision backbone.

        Sets ``self.backbone`` (with ``avgpool`` and ``fc`` removed) and
        ``self.fpn_in_channels`` (channel counts of the C3, C4, C5 outputs).

        Raises:
            NotImplementedError: for an unknown ``backbone`` name.
        """
        if backbone == 'res34':
            self.backbone = models.resnet34(pretrained=True)
            self.fpn_in_channels = [128, 256, 512]
        elif backbone == 'res50':
            self.backbone = models.resnet50(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        elif backbone == 'res101':
            self.backbone = models.resnet101(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        elif backbone == 'res152':
            self.backbone = models.resnet152(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        elif backbone == 'resnext50':
            self.backbone = models.resnext50_32x4d(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        else:
            raise NotImplementedError
        # The classification tail is unused; only the conv stages feed the FPN.
        del self.backbone.avgpool
        del self.backbone.fc

    def ims_2_features(self, ims):
        """Run the backbone and return the C3, C4, C5 feature maps."""
        c1 = self.backbone.relu(self.backbone.bn1(self.backbone.conv1(ims)))
        c2 = self.backbone.layer1(self.backbone.maxpool(c1))
        c3 = self.backbone.layer2(c2)
        c4 = self.backbone.layer3(c3)
        c5 = self.backbone.layer4(c4)
        # c_i shape: bs, C, H, W
        return [c3, c4, c5]

    def forward(self, ims, gt_boxes=None, test_conf=None, process=None):
        """Compute training losses or eval-mode detections.

        Args:
            ims: input image batch.
            gt_boxes: label boxes (training only). The last element is
                presumably the class (1, otherwise -1) -- confirm against the
                dataset.
            test_conf: optional score threshold override (eval only).
            process: training progress used by ``calc_mining_param``;
                required in training mode.

        Returns:
            Training: ``{'loss_cls': ..., 'loss_reg': ...}``.
            Eval: ``[nms_scores, nms_class, output_boxes]`` from ``decoder``.
        """
        anchors = self.anchor_generator(ims)  # shape = (batch_size, num_all_anchors, 5)

        # Feature maps from the network, fed to two heads: one for
        # classification and one for regression.
        features = self.fpn(self.ims_2_features(ims))
        cls_score = torch.cat([self.cls_head(f) for f in features], dim=1)
        bbox_pred = torch.cat([self.reg_head(f) for f in features], dim=1)

        if not self.training:
            # Eval: return [nms_scores, nms_class, output_boxes]. decoder()
            # decodes bbox_pred itself, so skip the training-only decode.
            return self.decoder(ims,
                                anchors,
                                cls_score,
                                bbox_pred,
                                test_conf=test_conf)

        # Decoded (regressed) boxes, detached so no gradient flows through them.
        bboxes = self.box_coder.decode(anchors, bbox_pred,
                                       mode='xywht').detach()
        # Gradually change how much the input IoU influences the matching
        # degree, i.e. adjust alpha as training progresses.
        bf_weight = self.calc_mining_param(process, 0.3)
        # Loss inputs: class scores for all boxes, predicted offsets, original
        # anchors, decoded predicted boxes, label boxes.
        losses = dict()
        losses['loss_cls'], losses['loss_reg'] = self.loss(
            cls_score, bbox_pred, anchors, bboxes, gt_boxes,
            md_thres=0.6,
            mining_param=(bf_weight, 1 - bf_weight, 5))
        return losses

    # decode() turns regression outputs back into boxes (with angle);
    # encode() computes the offsets used by the loss.
    def decoder(self,
                ims,
                anchors,
                cls_score,
                bbox_pred,
                thresh=0.6,
                nms_thresh=0.2,
                test_conf=None):
        """Turn raw head outputs into NMS-filtered detections for image 0.

        Args:
            ims: input images (passed to ``clip_boxes``).
            anchors: anchors, indexed ``[image, anchor, :]``.
            cls_score: class scores, indexed ``[image, anchor, class]``.
            bbox_pred: regression outputs, indexed ``[image, anchor, :]``.
            thresh: minimum max-class score for a box to be kept.
            nms_thresh: threshold passed to ``nms``.
            test_conf: if not None, overrides ``thresh``.

        Returns:
            ``[nms_scores, nms_class, output_boxes]``, where ``output_boxes``
            concatenates each kept decoded box with its anchor along dim 1.
            If no score reaches the threshold, returns
            ``[zeros(1), zeros(1), zeros(1, 5)]``.
        """
        if test_conf is not None:
            thresh = test_conf
        # Actual predicted boxes. The original note said "two coordinates plus
        # angle", but the mode is 'xywht' -- confirm the layout.
        bboxes = self.box_coder.decode(anchors, bbox_pred, mode='xywht')
        bboxes = clip_boxes(bboxes, ims)
        scores = torch.max(cls_score, dim=2, keepdim=True)[0]
        # NOTE(review): only batch index 0 is considered below.
        keep = (scores >= thresh)[0, :, 0]
        if keep.sum() == 0:
            # NOTE(review): these placeholders are created on the default
            # device and have 5 box columns, unlike the non-empty result which
            # concatenates box and anchor columns -- confirm callers expect this.
            return [torch.zeros(1), torch.zeros(1), torch.zeros(1, 5)]
        scores = scores[:, keep, :]
        anchors = anchors[:, keep, :]
        cls_score = cls_score[:, keep, :]
        bboxes = bboxes[:, keep, :]
        # NMS
        anchors_nms_idx = nms(
            torch.cat([bboxes, scores], dim=2)[0, :, :], nms_thresh)
        nms_scores, nms_class = cls_score[0, anchors_nms_idx, :].max(dim=1)
        output_boxes = torch.cat(
            [bboxes[0, anchors_nms_idx, :], anchors[0, anchors_nms_idx, :]],
            dim=1)
        return [nms_scores, nms_class, output_boxes]

    def freeze_bn(self):
        """Put every BatchNorm2d layer in eval mode (frozen running stats)."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()

    def calc_mining_param(self, process, alpha):
        """Weight used as the first entry of the loss's ``mining_param``.

        Gradually adjusts the influence of the input IoU on the matching
        degree. Returns 1.0 while ``process < 0.1`` and ``alpha`` once
        ``process > 0.3``. In between it interpolates linearly; the line
        equals 1.0 at 0.1 and ``alpha`` at 0.3, so the schedule is continuous.

        Raises:
            ValueError: if ``process`` is None (it must be supplied in training).
        """
        if process is None:
            raise ValueError(
                '`process` (training progress) is required in training mode')
        if process < 0.1:
            bf_weight = 1.0
        elif process > 0.3:
            bf_weight = alpha
        else:
            bf_weight = 5 * (alpha - 1) * process + 1.5 - 0.5 * alpha
        return bf_weight
Example #11
0
class IntegratedLoss(nn.Module):
    """Focal classification loss plus (optionally weighted) box-regression loss.

    Positive anchors are chosen by a "matching degree" ``md`` of shape
    ``(num_anchors, num_gt)``. With ``var == -1`` in ``mining_param``, ``md`` is
    the rotated overlap between the input anchors and the ground truth (``sa``).
    Otherwise it also uses the overlap of the refined anchors (``fa``). In that
    case, per-anchor matching weights rescale the positive classification terms
    and are passed to the regression criterion.
    """

    def __init__(self, alpha=0.25, gamma=2.0, func='smooth'):
        """
        Args:
            alpha: focal weight for entries whose target is 1 (``1 - alpha`` otherwise).
            gamma: focal-loss focusing exponent.
            func: regression criterion: 'smooth' (smooth_l1_loss), 'mse'
                (F.mse_loss) or 'balanced' (balanced_l1_loss).

        Raises:
            ValueError: if ``func`` is not one of the names above. Previously
                ``self.criteron`` was silently left unset.
        """
        super(IntegratedLoss, self).__init__()
        if func == 'smooth':
            self.criteron = smooth_l1_loss
        elif func == 'mse':
            # NOTE(review): F.mse_loss takes no ``weight`` argument, but forward()
            # passes ``weight=`` when var != -1, so 'mse' only works with var == -1.
            self.criteron = F.mse_loss
        elif func == 'balanced':
            self.criteron = balanced_l1_loss
        else:
            raise ValueError(
                f"Unknown regression loss func {func!r}; "
                "expected 'smooth', 'mse' or 'balanced'.")
        self.alpha = alpha
        self.gamma = gamma
        self.box_coder = BoxCoder()

    @staticmethod
    def _rotated_overlaps(boxes, gt_boxes, device):
        """Return rbbx_overlaps between ``boxes`` and ``gt_boxes`` (after xyxy2xywh_a).

        A NumPy result is converted to a tensor on ``device``; a tensor result is
        returned unchanged.
        """
        overlaps = rbbx_overlaps(
            xyxy2xywh_a(boxes.cpu().numpy()),
            xyxy2xywh_a(gt_boxes.cpu().numpy()),
        )
        if not torch.is_tensor(overlaps):
            overlaps = torch.from_numpy(overlaps).to(device)
        return overlaps

    @staticmethod
    def _matching_weight(md, positive_indices, md_thres):
        """Compensated matching weights for the positive anchors.

        Args:
            md: matching degrees, ``(num_anchors, num_gt)``.
            positive_indices: boolean mask over anchors.
            md_thres: positive threshold.

        Returns:
            Tensor ``(num_positive, num_gt)``. Two cases apply:
            - Entries with ``md >= md_thres``, and each ground truth's best
              positive anchor, get ``md + (1 - best md of that ground truth)``.
              The best anchor of every ground truth therefore gets weight 1.
            - All other entries get ``2 * md``.
        """
        pos = md[positive_indices]
        pos_mask = torch.ge(pos, md_thres)
        max_pos, argmax_pos = pos.max(0)
        # Always mark the best positive anchor of every ground truth.
        pos_mask[argmax_pos, torch.arange(md.shape[1], device=md.device)] = True
        comp = torch.where(pos_mask, (1 - max_pos).repeat(len(pos), 1), pos)
        return comp + pos

    def forward(self, classifications, regressions, anchors, refined_achors, annotations,
                md_thres=0.5, mining_param=(1, 0., -1), ref=False):
        """Compute the classification and regression losses for a batch.

        Args:
            classifications: scores indexed ``[batch, anchor, class]``; clamped to
                ``[1e-4, 1 - 1e-4]`` and fed to ``log``, so presumably probabilities.
            regressions: regression outputs indexed ``[batch, anchor, :]``.
            anchors: anchors indexed ``[batch, anchor, :]``.
            refined_achors: refined anchors, same layout; only read when ``var != -1``.
            annotations: ground truth ``[batch, num_gt, :]`` with the class label in
                the last column; rows labelled -1 are treated as padding.
            md_thres: matching degree needed to be positive. Anchors below
                ``md_thres - 0.1`` are negatives. Anchors in between are ignored
                unless forced positive.
            mining_param: ``(alpha, beta, var)``. ``var == -1`` uses ``md = sa``
                without matching weights. ``var == 0`` uses
                ``md = |alpha*sa + beta*fa|``. Any other value uses
                ``md = |alpha*sa + beta*fa - |fa - sa|**var|``.
            ref: unused.

        Returns:
            ``(loss_cls, loss_reg)``: per-image losses averaged over the batch,
            each of shape ``(1,)`` (assuming the criterion returns a scalar).

        Raises:
            FloatingPointError: if a regression loss is NaN or infinite.
        """
        device = classifications.device
        alpha, beta, var = mining_param
        # Dynamic anchor selection (matching weights) is active unless var == -1.
        das = var != -1
        cls_losses = []
        reg_losses = []
        for j in range(classifications.shape[0]):
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, -1] != -1]
            if bbox_annotation.shape[0] == 0:
                cls_losses.append(torch.zeros((), device=device))
                reg_losses.append(torch.zeros((), device=device))
                continue
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            # Matching degree between every anchor and every ground truth.
            sa = self._rotated_overlaps(anchors[j, :, :], bbox_annotation[:, :-1], device)
            if das:
                fa = self._rotated_overlaps(refined_achors[j, :, :],
                                            bbox_annotation[:, :-1], device)
                if var == 0:
                    md = abs(alpha * sa + beta * fa)
                else:
                    md = abs((alpha * sa + beta * fa) - abs(fa - sa)**var)
            else:
                md = sa

            # Best ground truth per anchor; threshold it to get positives.
            iou_max, iou_argmax = torch.max(md, dim=1)
            positive_indices = torch.ge(iou_max, md_thres)
            # Make sure every ground truth gets a positive: force the best anchor
            # of any ground truth whose best matching degree is below threshold.
            max_gt, argmax_gt = md.max(0)
            low_quality = max_gt < md_thres
            if low_quality.any():
                positive_indices[argmax_gt[low_quality]] = True

            if das:
                matching_weight = self._matching_weight(md, positive_indices, md_thres)

            # Classification targets: -1 ignored, 0 negative, 1 assigned class.
            cls_targets = torch.full(classification.shape, -1.0, device=device)
            cls_targets[torch.lt(iou_max, md_thres - 0.1), :] = 0
            num_positive_anchors = positive_indices.sum()
            assigned_annotations = bbox_annotation[iou_argmax, :]
            positive_labels = assigned_annotations[positive_indices, -1].long()
            cls_targets[positive_indices, :] = 0
            cls_targets[positive_indices, positive_labels] = 1

            # Focal loss.
            is_target = torch.eq(cls_targets, 1.)
            alpha_factor = torch.full_like(cls_targets, self.alpha)
            alpha_factor = torch.where(is_target, alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(is_target, 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
            bin_cross_entropy = -(
                cls_targets * torch.log(classification + 1e-6) +
                (1.0 - cls_targets) * torch.log(1.0 - classification + 1e-6))
            if das:
                # Target-0 entries weigh 1, ignored entries 0, and each positive's
                # assigned-class entry 1 + its largest matching weight.
                soft_weight = torch.eq(cls_targets, 0.).to(cls_targets.dtype)
                soft_weight[positive_indices, positive_labels] = (
                    matching_weight.max(1)[0] + 1)
                cls_loss = focal_weight * bin_cross_entropy * soft_weight
            else:
                cls_loss = focal_weight * bin_cross_entropy
            cls_loss = torch.where(torch.ne(cls_targets, -1.0), cls_loss,
                                   torch.zeros_like(cls_loss))
            cls_losses.append(
                cls_loss.sum() /
                torch.clamp(num_positive_anchors.float(), min=1.0))

            # Regression loss on positive anchors only.
            if num_positive_anchors > 0:
                all_rois = anchors[j, positive_indices, :]
                gt_boxes = assigned_annotations[positive_indices, :]
                reg_targets = self.box_coder.encode(all_rois, gt_boxes)
                if das:
                    reg_loss = self.criteron(regression[positive_indices, :],
                                             reg_targets,
                                             weight=matching_weight)
                else:
                    reg_loss = self.criteron(regression[positive_indices, :],
                                             reg_targets)
                if not torch.isfinite(reg_loss):
                    # Fail loudly: a NaN/Inf loss would corrupt the gradients, and an
                    # interactive debugger cannot be used in unattended training.
                    raise FloatingPointError(
                        f'IntegratedLoss: non-finite regression loss {reg_loss} '
                        f'for batch item {j}')
                reg_losses.append(reg_loss)
            else:
                reg_losses.append(torch.zeros((), device=device))

        loss_cls = torch.stack(cls_losses).mean(dim=0, keepdim=True)
        loss_reg = torch.stack(reg_losses).mean(dim=0, keepdim=True)
        return loss_cls, loss_reg
# Example #12
# 0
class IntegratedLoss(nn.Module):
    """Focal classification loss plus (optionally weighted) box-regression loss.

    Positive anchors are chosen by a "matching degree" ``md`` of shape
    ``(num_anchors, num_gt)``. With ``var == -1`` in ``mining_param``, ``md`` is
    the rotated overlap between the input anchors and the ground truth (spatial
    alignment, ``sa``). Otherwise it also uses the overlap of the refined anchors
    (feature alignment, ``fa``) and subtracts an ``|fa - sa|**var`` penalty when
    ``var != 0``. In that case, per-anchor matching weights rescale the positive
    classification terms and are passed to the regression criterion.
    """

    def __init__(self, alpha=0.25, gamma=2.0, func='smooth'):
        """
        Args:
            alpha: focal weight for entries whose target is 1 (``1 - alpha`` otherwise).
            gamma: focal-loss focusing exponent.
            func: regression criterion: 'smooth' (smooth_l1_loss), 'mse'
                (F.mse_loss) or 'balanced' (balanced_l1_loss).

        Raises:
            ValueError: if ``func`` is not one of the names above. Previously
                ``self.criteron`` was silently left unset.
        """
        super(IntegratedLoss, self).__init__()
        if func == 'smooth':
            self.criteron = smooth_l1_loss
        elif func == 'mse':
            # NOTE(review): F.mse_loss takes no ``weight`` argument, but forward()
            # passes ``weight=`` when var != -1, so 'mse' only works with var == -1.
            self.criteron = F.mse_loss
        elif func == 'balanced':
            self.criteron = balanced_l1_loss
        else:
            raise ValueError(
                f"Unknown regression loss func {func!r}; "
                "expected 'smooth', 'mse' or 'balanced'.")
        self.alpha = alpha
        self.gamma = gamma
        self.box_coder = BoxCoder()

    @staticmethod
    def _rotated_overlaps(boxes, gt_boxes, device):
        """Return rbbx_overlaps between ``boxes`` and ``gt_boxes`` (after xyxy2xywh_a).

        A NumPy result is converted to a tensor on ``device``; a tensor result is
        returned unchanged.
        """
        overlaps = rbbx_overlaps(
            xyxy2xywh_a(boxes.cpu().numpy()),
            xyxy2xywh_a(gt_boxes.cpu().numpy()),
        )
        if not torch.is_tensor(overlaps):
            overlaps = torch.from_numpy(overlaps).to(device)
        return overlaps

    @staticmethod
    def _matching_weight(md, positive_indices, md_thres):
        """Compensated matching weights for the positive anchors.

        Args:
            md: matching degrees, ``(num_anchors, num_gt)``.
            positive_indices: boolean mask over anchors.
            md_thres: positive threshold.

        Returns:
            Tensor ``(num_positive, num_gt)``. Two cases apply:
            - Entries with ``md >= md_thres``, and each ground truth's best
              positive anchor, get ``md + (1 - best md of that ground truth)``.
              The best anchor of every ground truth therefore gets weight 1.
            - All other entries get ``2 * md``.
        """
        pos = md[positive_indices]
        pos_mask = torch.ge(pos, md_thres)
        max_pos, argmax_pos = pos.max(0)
        # Always mark the best positive anchor of every ground truth.
        pos_mask[argmax_pos, torch.arange(md.shape[1], device=md.device)] = True
        comp = torch.where(pos_mask, (1 - max_pos).repeat(len(pos), 1), pos)
        return comp + pos

    def forward(self, classifications, regressions, anchors, refined_achors, annotations,
                md_thres=0.5, mining_param=(1, 0., -1), ref=False):
        """Compute the classification and regression losses for a batch.

        Args:
            classifications: scores indexed ``[batch, anchor, class]``; clamped to
                ``[1e-4, 1 - 1e-4]`` and fed to ``log``, so presumably probabilities.
            regressions: regression outputs (box offsets including the angle)
                indexed ``[batch, anchor, :]``.
            anchors: anchors indexed ``[batch, anchor, :]``.
            refined_achors: refined anchors, same layout; only read when ``var != -1``.
            annotations: ground truth ``[batch, num_gt, :]`` with the class label in
                the last column; rows labelled -1 are treated as padding.
            md_thres: matching degree needed to be positive. Anchors below
                ``md_thres - 0.1`` are negatives. Anchors in between are ignored
                unless forced positive.
            mining_param: ``(alpha, beta, var)``. ``var == -1`` uses ``md = sa``
                without matching weights. ``var == 0`` uses
                ``md = |alpha*sa + beta*fa|``. Any other value uses
                ``md = |alpha*sa + beta*fa - |fa - sa|**var|``.
            ref: unused.

        Returns:
            ``(loss_cls, loss_reg)``: per-image losses averaged over the batch,
            each of shape ``(1,)`` (assuming the criterion returns a scalar).

        Raises:
            FloatingPointError: if a regression loss is NaN or infinite.
        """
        device = classifications.device
        alpha, beta, var = mining_param
        # Dynamic anchor selection (matching weights) is active unless var == -1.
        das = var != -1
        cls_losses = []
        reg_losses = []
        for j in range(classifications.shape[0]):  # one image at a time
            classification = classifications[j, :, :]
            regression = regressions[j, :, :]
            bbox_annotation = annotations[j, :, :]
            bbox_annotation = bbox_annotation[bbox_annotation[:, -1] != -1]
            if bbox_annotation.shape[0] == 0:  # no ground truth in this image
                cls_losses.append(torch.zeros((), device=device))
                reg_losses.append(torch.zeros((), device=device))
                continue
            classification = torch.clamp(classification, 1e-4, 1.0 - 1e-4)

            # Spatial alignment: overlap of the input anchors with the ground truth.
            sa = self._rotated_overlaps(anchors[j, :, :], bbox_annotation[:, :-1], device)
            if das:
                # Feature alignment: overlap of the refined anchors with the ground truth.
                fa = self._rotated_overlaps(refined_achors[j, :, :],
                                            bbox_annotation[:, :-1], device)
                if var == 0:
                    md = abs(alpha * sa + beta * fa)
                else:  # subtract the |fa - sa|**var penalty term
                    md = abs((alpha * sa + beta * fa) - abs(fa - sa)**var)
            else:
                md = sa

            # Best matching degree per anchor (over all ground truths); threshold it
            # to pick positives. Only positive/negative matters here, not which class.
            iou_max, iou_argmax = torch.max(md, dim=1)
            positive_indices = torch.ge(iou_max, md_thres)
            # The best anchor of each ground truth; if its matching degree is below
            # the threshold, force it positive so no ground truth is left unmatched.
            max_gt, argmax_gt = md.max(0)
            low_quality = max_gt < md_thres
            if low_quality.any():
                positive_indices[argmax_gt[low_quality]] = True

            if das:
                matching_weight = self._matching_weight(md, positive_indices, md_thres)

            # Classification targets: -1 ignored, 0 negative, 1 assigned class.
            cls_targets = torch.full(classification.shape, -1.0, device=device)
            cls_targets[torch.lt(iou_max, md_thres - 0.1), :] = 0
            num_positive_anchors = positive_indices.sum()
            assigned_annotations = bbox_annotation[iou_argmax, :]
            positive_labels = assigned_annotations[positive_indices, -1].long()
            cls_targets[positive_indices, :] = 0
            cls_targets[positive_indices, positive_labels] = 1

            # Focal loss.
            is_target = torch.eq(cls_targets, 1.)
            alpha_factor = torch.full_like(cls_targets, self.alpha)
            alpha_factor = torch.where(is_target, alpha_factor, 1. - alpha_factor)
            focal_weight = torch.where(is_target, 1. - classification, classification)
            focal_weight = alpha_factor * torch.pow(focal_weight, self.gamma)
            bin_cross_entropy = -(
                cls_targets * torch.log(classification + 1e-6) +
                (1.0 - cls_targets) * torch.log(1.0 - classification + 1e-6))
            if das:
                # Target-0 entries weigh 1, ignored entries 0, and each positive's
                # assigned-class entry 1 + its largest matching weight.
                soft_weight = torch.eq(cls_targets, 0.).to(cls_targets.dtype)
                soft_weight[positive_indices, positive_labels] = (
                    matching_weight.max(1)[0] + 1)
                cls_loss = focal_weight * bin_cross_entropy * soft_weight
            else:
                cls_loss = focal_weight * bin_cross_entropy
            # Ignored entries (target -1) contribute nothing.
            cls_loss = torch.where(torch.ne(cls_targets, -1.0), cls_loss,
                                   torch.zeros_like(cls_loss))
            cls_losses.append(
                cls_loss.sum() /
                torch.clamp(num_positive_anchors.float(), min=1.0))

            # Regression loss on positive anchors only.
            if num_positive_anchors > 0:
                all_rois = anchors[j, positive_indices, :]
                gt_boxes = assigned_annotations[positive_indices, :]
                # Offsets from each positive anchor to its assigned ground truth.
                reg_targets = self.box_coder.encode(all_rois, gt_boxes)
                if das:
                    reg_loss = self.criteron(regression[positive_indices, :],
                                             reg_targets,
                                             weight=matching_weight)
                else:
                    reg_loss = self.criteron(regression[positive_indices, :],
                                             reg_targets)
                if not torch.isfinite(reg_loss):
                    # Fail loudly: a NaN/Inf loss would corrupt the gradients, and an
                    # interactive debugger cannot be used in unattended training.
                    raise FloatingPointError(
                        f'IntegratedLoss: non-finite regression loss {reg_loss} '
                        f'for batch item {j}')
                reg_losses.append(reg_loss)
            else:
                reg_losses.append(torch.zeros((), device=device))

        # Average the per-image losses over the batch.
        loss_cls = torch.stack(cls_losses).mean(dim=0, keepdim=True)
        loss_reg = torch.stack(reg_losses).mean(dim=0, keepdim=True)
        return loss_cls, loss_reg
class STELA(nn.Module):
    """Anchor-based detector that regresses 5 box parameters per anchor.

    Pipeline: torchvision backbone -> FPN (with a LastLevelP6P7 top block) ->
    ``num_refining`` regression heads that successively refine the anchors ->
    one classification head and one regression head evaluated against the
    last (most refined) set of anchors.
    """

    def __init__(self, backbone='res50', num_classes=2, num_refining=1):
        """
        Args:
            backbone: 'res34', 'res50' or 'resnext50' (see init_backbone).
            num_classes: number of outputs of the classification head.
            num_refining: number of anchor-refinement stages; 0 disables them.
        """
        super(STELA, self).__init__()
        self.anchor_generator = Anchors()
        self.num_anchors = self.anchor_generator.num_anchors
        self.init_backbone(backbone)
        self.fpn = FPN(in_channels_list=self.fpn_in_channels,
                       out_channels=256,
                       top_blocks=LastLevelP6P7(self.fpn_in_channels[-1], 256))
        self.cls_head = CLSHead(in_channels=256,
                                feat_channels=256,
                                num_stacked=1,
                                num_anchors=self.num_anchors,
                                num_classes=num_classes)
        self.reg_head = REGHead(in_channels=256,
                                feat_channels=256,
                                num_stacked=1,
                                num_anchors=self.num_anchors,
                                num_regress=5)
        self.num_refining = num_refining
        if self.num_refining > 0:
            # One extra regression head (and a shared loss) per refinement stage.
            self.ref_heads = nn.ModuleList([
                REGHead(in_channels=256,
                        feat_channels=256,
                        num_stacked=1,
                        num_anchors=self.num_anchors,
                        num_regress=5) for _ in range(self.num_refining)
            ])
            self.loss_ref = RegressLoss(func='smooth')
        self.loss_cls = FocalLoss()
        self.loss_reg = RegressLoss(func='smooth')
        self.box_coder = BoxCoder()

    def init_backbone(self, backbone):
        """Create a pretrained torchvision backbone and record the FPN input channels.

        ``fpn_in_channels`` lists the channel counts of the layer2/3/4 outputs
        returned by ``ims_2_features``.

        Raises:
            NotImplementedError: if ``backbone`` is not a supported name.
        """
        if backbone == 'res34':
            self.backbone = models.resnet34(pretrained=True)
            self.fpn_in_channels = [128, 256, 512]
        elif backbone == 'res50':
            self.backbone = models.resnet50(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        elif backbone == 'resnext50':
            self.backbone = models.resnext50_32x4d(pretrained=True)
            self.fpn_in_channels = [512, 1024, 2048]
        else:
            raise NotImplementedError(
                f"Unsupported backbone {backbone!r}; "
                "expected 'res34', 'res50' or 'resnext50'.")
        # The classification tail is never used; only the conv stages feed the FPN.
        del self.backbone.avgpool
        del self.backbone.fc

    def ims_2_features(self, ims):
        """Run the backbone stem and stages; return ``[c3, c4, c5]``.

        These are the outputs of layer2, layer3 and layer4, in the order that
        matches ``fpn_in_channels``.
        """
        c1 = self.backbone.relu(self.backbone.bn1(self.backbone.conv1(ims)))
        c2 = self.backbone.layer1(self.backbone.maxpool(c1))
        c3 = self.backbone.layer2(c2)
        c4 = self.backbone.layer3(c3)
        c5 = self.backbone.layer4(c4)
        return [c3, c4, c5]

    def forward(self, ims, gt_boxes=None):
        """Predict on ``ims``; return losses in training mode, detections otherwise.

        Args:
            ims: image batch.
            gt_boxes: ground-truth annotations; required when ``self.training``.

        Returns:
            Training: dict with 'loss_cls', 'loss_reg' and, if ``num_refining > 0``,
            'loss_ref' (the mean over refinement stages).
            Eval: the list returned by ``decoder`` for the final anchors.
        """
        anchors_list, offsets_list = [], []
        original_anchors = self.anchor_generator(ims)
        anchors_list.append(original_anchors)
        features = self.fpn(self.ims_2_features(ims))

        # Anchor refining: each stage decodes its offsets relative to the previous
        # anchors. detach() stops gradients from flowing through the refined anchors.
        if self.num_refining > 0:
            for i in range(self.num_refining):
                bbox_pred = torch.cat(
                    [self.ref_heads[i](feature) for feature in features],
                    dim=1)
                refined_anchors = self.box_coder.decode(anchors_list[-1],
                                                        bbox_pred,
                                                        mode='wht').detach()
                anchors_list.append(refined_anchors)
                offsets_list.append(bbox_pred)

        cls_score = torch.cat([self.cls_head(feature) for feature in features],
                              dim=1)
        bbox_pred = torch.cat([self.reg_head(feature) for feature in features],
                              dim=1)
        if self.training:
            losses = dict()
            if self.num_refining > 0:
                ref_losses = []
                for i in range(self.num_refining):
                    # Later stages use a stricter IoU threshold (0.3, 0.4, ...).
                    ref_losses.append(
                        self.loss_ref(offsets_list[i],
                                      anchors_list[i],
                                      gt_boxes,
                                      iou_thresh=(0.3 + i * 0.1)))
                losses['loss_ref'] = torch.stack(ref_losses).mean(dim=0,
                                                                  keepdim=True)
            losses['loss_cls'] = self.loss_cls(cls_score,
                                               anchors_list[-1],
                                               gt_boxes,
                                               iou_thresh=0.5)
            losses['loss_reg'] = self.loss_reg(bbox_pred,
                                               anchors_list[-1],
                                               gt_boxes,
                                               iou_thresh=0.5)
            return losses
        else:
            return self.decoder(ims, anchors_list[-1], cls_score, bbox_pred)

    def decoder(self,
                ims,
                anchors,
                cls_score,
                bbox_pred,
                thresh=0.3,
                nms_thresh=0.3):
        """Decode, score-filter and NMS the predictions of batch element 0 only.

        Returns:
            ``[nms_scores, nms_class, output_boxes]``. Each row of ``output_boxes``
            is the decoded box followed by its anchor. If no anchor reaches
            ``thresh``, returns ``[zeros(1), zeros(1), zeros(1, 5)]``.
            NOTE(review): that fallback has 5 columns and is on the CPU, unlike
            the regular output.
        """
        bboxes = self.box_coder.decode(anchors, bbox_pred, mode='xywht')
        bboxes = clip_boxes(bboxes, ims)
        scores = torch.max(cls_score, dim=2, keepdim=True)[0]
        keep = (scores >= thresh)[0, :, 0]
        if keep.sum() == 0:
            return [torch.zeros(1), torch.zeros(1), torch.zeros(1, 5)]
        scores = scores[:, keep, :]
        anchors = anchors[:, keep, :]
        cls_score = cls_score[:, keep, :]
        bboxes = bboxes[:, keep, :]
        # nms receives each box with its score appended as the last column.
        anchors_nms_idx = nms(
            torch.cat([bboxes, scores], dim=2)[0, :, :], nms_thresh)
        nms_scores, nms_class = cls_score[0, anchors_nms_idx, :].max(dim=1)
        output_boxes = torch.cat(
            [bboxes[0, anchors_nms_idx, :], anchors[0, anchors_nms_idx, :]],
            dim=1)
        return [nms_scores, nms_class, output_boxes]

    def freeze_bn(self):
        """Put every ``nn.BatchNorm2d`` submodule in eval mode (requires_grad untouched)."""
        for layer in self.modules():
            if isinstance(layer, nn.BatchNorm2d):
                layer.eval()