Example #1
    def forward(self, x, target=None):
        # backbone
        C_4, C_5 = self.backbone(x)

        # detection head
        # multi scale feature map fusion
        C_5 = self.conv_set_2(C_5)
        C_5_up = F.interpolate(self.conv_1x1_2(C_5),
                               scale_factor=2.0,
                               mode='bilinear',
                               align_corners=True)

        C_4 = torch.cat([C_4, C_5_up], dim=1)
        C_4 = self.conv_set_1(C_4)

        # head
        # s = 32
        C_5 = self.extra_conv_2(C_5)
        pred_2 = self.pred_2(C_5)

        # s = 16
        pred_1 = self.pred_1(C_4)

        preds = [pred_1, pred_2]
        total_conf_pred = []
        total_cls_pred = []
        total_txtytwth_pred = []
        B = HW = 0
        for pred in preds:
            B_, abC_, H_, W_ = pred.size()

            # [B, anchor_n * C, H, W] -> [B, H, W, anchor_n * C] -> [B, H*W, anchor_n*C]
            pred = pred.permute(0, 2, 3,
                                1).contiguous().view(B_, H_ * W_, abC_)

            # split the prediction into conf_pred, cls_pred and txtytwth_pred
            # [B, H*W*anchor_n, 1]
            conf_pred = pred[:, :, :1 * self.anchor_number].contiguous().view(
                B_, H_ * W_ * self.anchor_number, 1)
            # [B, H*W*anchor_n, num_cls]
            cls_pred = pred[:, :,
                            1 * self.anchor_number:(1 + self.num_classes) *
                            self.anchor_number].contiguous().view(
                                B_, H_ * W_ * self.anchor_number,
                                self.num_classes)
            # [B, H*W, anchor_n*4]
            txtytwth_pred = pred[:, :, (1 + self.num_classes) *
                                 self.anchor_number:].contiguous()

            total_conf_pred.append(conf_pred)
            total_cls_pred.append(cls_pred)
            total_txtytwth_pred.append(txtytwth_pred)
            B = B_
            HW += H_ * W_

        conf_pred = torch.cat(total_conf_pred, 1)
        cls_pred = torch.cat(total_cls_pred, 1)
        txtytwth_pred = torch.cat(total_txtytwth_pred, 1)

        # test
        if not self.trainable:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.anchor_number, 4)
            with torch.no_grad():
                # batch size = 1
                all_obj = torch.sigmoid(conf_pred)[0]  # index 0 since the batch size is 1
                all_bbox = torch.clamp(
                    (self.decode_boxes(txtytwth_pred) / self.scale_torch)[0],
                    0., 1.)
                all_class = (torch.softmax(cls_pred[0, :, :], dim=1) * all_obj)
                # separate box pred and class conf
                all_obj = all_obj.to('cpu').numpy()
                all_class = all_class.to('cpu').numpy()
                all_bbox = all_bbox.to('cpu').numpy()

                bboxes, scores, cls_inds = self.postprocess(
                    all_bbox, all_class)

                # print(len(all_boxes))
                return bboxes, scores, cls_inds

        else:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.anchor_number, 4)
            # decode bboxes without grad, since the IoU is used as the objectness label
            with torch.no_grad():
                x1y1x2y2_pred = (self.decode_boxes(txtytwth_pred) /
                                 self.scale_torch).view(-1, 4)

            txtytwth_pred = txtytwth_pred.view(B, -1, 4)

            x1y1x2y2_gt = target[:, :, 7:].view(-1, 4)

            # compute iou
            iou = tools.iou_score(x1y1x2y2_pred, x1y1x2y2_gt).view(B, -1, 1)
            # print(iou.min(), iou.max())

            # use the IoU between the pred bbox and the gt bbox as the conf label.
            # [obj, cls, txtytwth, x1y1x2y2] -> [conf, obj, cls, txtytwth]
            target = torch.cat([iou, target[:, :, :7]], dim=2)

            conf_loss, cls_loss, txtytwth_loss, total_loss = tools.loss(
                pred_conf=conf_pred,
                pred_cls=cls_pred,
                pred_txtytwth=txtytwth_pred,
                label=target,
                num_classes=self.num_classes,
                obj_loss_f='mse')

            return conf_loss, cls_loss, txtytwth_loss, total_loss
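
Both branches above rely on self.decode_boxes to turn the raw txtytwth offsets into x1y1x2y2 boxes before normalizing by scale_torch. That method is not shown on this page; the sketch below is a minimal stand-in assuming the usual YOLO parameterization (sigmoid on tx/ty plus the grid-cell offset, exp on tw/th times the anchor size), so the repository's actual implementation may differ.

import torch

def decode_boxes_sketch(txtytwth, grid_xy, anchors_wh, stride):
    # txtytwth: [B, HW, A, 4]; grid_xy: [HW, 2] cell offsets; anchors_wh: [A, 2] in pixels
    B = txtytwth.size(0)
    # box center in pixels: (sigmoid(tx, ty) + cell offset) * stride
    cxcy = (torch.sigmoid(txtytwth[..., :2]) + grid_xy[None, :, None, :]) * stride
    # box size in pixels: exp(tw, th) * anchor size
    wh = torch.exp(txtytwth[..., 2:]) * anchors_wh[None, None, :, :]
    # convert center-size to corner format [x1, y1, x2, y2]
    x1y1x2y2 = torch.cat([cxcy - wh * 0.5, cxcy + wh * 0.5], dim=-1)
    return x1y1x2y2.view(B, -1, 4)
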
Example #2
    def forward(self, x, target=None):
        # backbone
        C4, C5 = self.backbone(x)

        # head
        C5 = self.convsets_1(C5)

        # route from 16th layer in darknet
        C4 = self.reorg(self.route_layer(C4))

        # route concatenate
        C5 = torch.cat([C4, C5], dim=1)
        C5 = self.convsets_2(C5)
        prediction = self.pred_(C5)

        B, abC, H, W = prediction.size()

        # [B, anchor_n * C, N, M] -> [B, N, M, anchor_n * C] -> [B, N*M, anchor_n*C]
        prediction = prediction.permute(0, 2, 3, 1).contiguous().view(B, H*W, abC)

        # split the prediction into conf_pred, cls_pred and txtytwth_pred
        # [B, H*W*anchor_n, 1]
        conf_pred = prediction[:, :, :1 * self.anchor_number].contiguous().view(B, H*W*self.anchor_number, 1)
        # [B, H*W*anchor_n, num_cls]
        cls_pred = prediction[:, :, 1 * self.anchor_number : (1 + self.num_classes) * self.anchor_number].contiguous().view(B, H*W*self.anchor_number, self.num_classes)
        # [B, H*W, anchor_n*4]
        txtytwth_pred = prediction[:, :, (1 + self.num_classes) * self.anchor_number:].contiguous()
        
        # test
        if not self.trainable:
            txtytwth_pred = txtytwth_pred.view(B, H*W, self.anchor_number, 4)
            with torch.no_grad():
                # batch size = 1                
                all_obj = torch.sigmoid(conf_pred)[0]  # index 0 since the batch size is 1
                all_bbox = torch.clamp((self.decode_boxes(txtytwth_pred) / self.scale_torch)[0], 0., 1.)
                all_class = (torch.softmax(cls_pred[0, :, :], 1) * all_obj)
                # separate box pred and class conf
                all_obj = all_obj.to('cpu').numpy()
                all_class = all_class.to('cpu').numpy()
                all_bbox = all_bbox.to('cpu').numpy()

                bboxes, scores, cls_inds = self.postprocess(all_bbox, all_class)

                return bboxes, scores, cls_inds

        else:
            txtytwth_pred = txtytwth_pred.view(B, H*W, self.anchor_number, 4)
            # decode bboxes without grad, since the IoU is used as the objectness label
            with torch.no_grad():
                x1y1x2y2_pred = (self.decode_boxes(txtytwth_pred) / self.scale_torch).view(-1, 4)

            txtytwth_pred = txtytwth_pred.view(B, H*W*self.anchor_number, 4)

            x1y1x2y2_gt = target[:, :, 7:].view(-1, 4)

            # compute iou
            iou = tools.iou_score(x1y1x2y2_pred, x1y1x2y2_gt).view(B, H*W*self.anchor_number, 1)
            # print(iou.min(), iou.max())

            # use the IoU between the pred bbox and the gt bbox as the conf label.
            # [obj, cls, txtytwth, x1y1x2y2] -> [conf, obj, cls, txtytwth]
            target = torch.cat([iou, target[:, :, :7]], dim=2)

            conf_loss, cls_loss, txtytwth_loss, total_loss = tools.loss(pred_conf=conf_pred, pred_cls=cls_pred,
                                                                        pred_txtytwth=txtytwth_pred,
                                                                        label=target,
                                                                        num_classes=self.num_classes)

            return conf_loss, cls_loss, txtytwth_loss, total_loss
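
self.reorg above is the YOLOv2-style passthrough (space-to-depth) layer: it trades spatial resolution for channels so that C4 can be concatenated with C5. A minimal sketch of that operation, assuming stride 2; the repository may implement it as a dedicated module.

import torch

def reorg_sketch(x, stride=2):
    # [B, C, H, W] -> [B, C*stride*stride, H//stride, W//stride]
    B, C, H, W = x.size()
    x = x.view(B, C, H // stride, stride, W // stride, stride)
    x = x.permute(0, 3, 5, 1, 2, 4).contiguous()
    return x.view(B, C * stride * stride, H // stride, W // stride)
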
Example #3
    def forward(self, x, target=None):
        # backbone
        c3, c4, c5 = self.backbone(x)

        # neck
        c5 = self.spp(c5)

        # FPN + PAN
        # head
        c6 = self.head_conv_0(c5)
        c7 = self.head_upsample_0(c6)  # s32->s16
        c8 = torch.cat([c7, c4], dim=1)
        c9 = self.head_csp_0(c8)
        # P3/8
        c10 = self.head_conv_1(c9)
        c11 = self.head_upsample_1(c10)  # s16->s8
        c12 = torch.cat([c11, c3], dim=1)
        c13 = self.head_csp_1(c12)  # to det
        # p4/16
        c14 = self.head_conv_2(c13)
        c15 = torch.cat([c14, c10], dim=1)
        c16 = self.head_csp_2(c15)  # to det
        # p5/32
        c17 = self.head_conv_3(c16)
        c18 = torch.cat([c17, c6], dim=1)
        c19 = self.head_csp_3(c18)  # to det

        # det
        pred_s = self.head_det_1(c13)
        pred_m = self.head_det_2(c16)
        pred_l = self.head_det_3(c19)

        preds = [pred_s, pred_m, pred_l]
        total_conf_pred = []
        total_cls_pred = []
        total_txtytwth_pred = []
        B = HW = 0
        for pred in preds:
            B_, abC_, H_, W_ = pred.size()

            # [B, anchor_n * C, H, W] -> [B, H, W, anchor_n * C] -> [B, H*W, anchor_n*C]
            pred = pred.permute(0, 2, 3,
                                1).contiguous().view(B_, H_ * W_, abC_)

            # split the prediction into conf_pred, cls_pred and txtytwth_pred
            # [B, H*W*anchor_n, 1]
            conf_pred = pred[:, :, :1 * self.num_anchors].contiguous().view(
                B_, H_ * W_ * self.num_anchors, 1)
            # [B, H*W*anchor_n, num_cls]
            cls_pred = pred[:, :, 1 * self.num_anchors:(1 + self.num_classes) *
                            self.num_anchors].contiguous().view(
                                B_, H_ * W_ * self.num_anchors,
                                self.num_classes)
            # [B, H*W, anchor_n*4]
            txtytwth_pred = pred[:, :, (1 + self.num_classes) *
                                 self.num_anchors:].contiguous()

            total_conf_pred.append(conf_pred)
            total_cls_pred.append(cls_pred)
            total_txtytwth_pred.append(txtytwth_pred)
            B = B_
            HW += H_ * W_

        conf_pred = torch.cat(total_conf_pred, dim=1)
        cls_pred = torch.cat(total_cls_pred, dim=1)
        txtytwth_pred = torch.cat(total_txtytwth_pred, dim=1)

        # train
        if self.trainable:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.num_anchors, 4)

            # decode the x1y1x2y2 coordinates from the txtytwth predictions
            x1y1x2y2_pred = (self.decode_boxes(txtytwth_pred) /
                             self.input_size).view(-1, 4)
            x1y1x2y2_gt = target[:, :, 7:].view(-1, 4)
            # compute the IoU between the pred boxes and the gt boxes
            iou_pred = tools.iou_score(x1y1x2y2_pred,
                                       x1y1x2y2_gt).view(B, -1, 1)

            # gt conf; this block ensures the IoU does not backpropagate gradients
            with torch.no_grad():
                gt_conf = iou_pred.clone()

            # use the IoU between the pred box and the gt box as the learning target for objectness.
            # [obj, cls, txtytwth, scale_weight, x1y1x2y2] -> [conf, obj, cls, txtytwth, scale_weight]
            target = torch.cat([gt_conf, target[:, :, :7]], dim=2)
            txtytwth_pred = txtytwth_pred.view(B, -1, 4)

            # compute the losses
            conf_loss, cls_loss, bbox_loss, iou_loss = tools.loss(
                pred_conf=conf_pred,
                pred_cls=cls_pred,
                pred_txtytwth=txtytwth_pred,
                pred_iou=iou_pred,
                label=target)

            return conf_loss, cls_loss, bbox_loss, iou_loss

        # test
        else:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.num_anchors, 4)
            with torch.no_grad():
                # batch size = 1
                # at test time the batch size is assumed to be 1,
                # so the batch dimension is not needed and is removed by indexing with [0].
                # [B, H*W*num_anchor, 1] -> [H*W*num_anchor, 1]
                conf_pred = torch.sigmoid(conf_pred)[0]
                # [B, H*W*num_anchor, 4] -> [H*W*num_anchor, 4]
                bboxes = torch.clamp(
                    (self.decode_boxes(txtytwth_pred) / self.input_size)[0],
                    0., 1.)
                # [B, H*W*num_anchor, C] -> [H*W*num_anchor, C],
                scores = torch.softmax(cls_pred[0, :, :], dim=1) * conf_pred

                # move the predictions to the CPU for post-processing
                scores = scores.to('cpu').numpy()
                bboxes = bboxes.to('cpu').numpy()

                # post-processing
                bboxes, scores, cls_inds = self.postprocess(bboxes, scores)

                return bboxes, scores, cls_inds
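
The training branch calls tools.iou_score on flattened [N, 4] boxes to build the objectness target. Below is a minimal element-wise IoU sketch for aligned x1y1x2y2 boxes; the real tools.iou_score may handle shapes and edge cases differently.

import torch

def iou_score_sketch(boxes_a, boxes_b, eps=1e-9):
    # boxes_a, boxes_b: [N, 4] in (x1, y1, x2, y2); returns IoU of shape [N]
    tl = torch.max(boxes_a[:, :2], boxes_b[:, :2])  # top-left of the intersection
    br = torch.min(boxes_a[:, 2:], boxes_b[:, 2:])  # bottom-right of the intersection
    wh = (br - tl).clamp(min=0)                     # clip non-overlapping pairs to zero
    inter = wh[:, 0] * wh[:, 1]
    area_a = (boxes_a[:, 2] - boxes_a[:, 0]) * (boxes_a[:, 3] - boxes_a[:, 1])
    area_b = (boxes_b[:, 2] - boxes_b[:, 0]) * (boxes_b[:, 3] - boxes_b[:, 1])
    return inter / (area_a + area_b - inter + eps)
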
Example #4
    def forward(self, x, target=None):
        # backbone
        c3, c4, c5 = self.backbone(x)

        # FPN: multi-scale feature fusion
        p5 = self.conv_set_3(c5)
        p5_up = F.interpolate(self.conv_1x1_3(p5),
                              scale_factor=2.0,
                              mode='bilinear',
                              align_corners=True)

        p4 = torch.cat([c4, p5_up], 1)
        p4 = self.conv_set_2(p4)
        p4_up = F.interpolate(self.conv_1x1_2(p4),
                              scale_factor=2.0,
                              mode='bilinear',
                              align_corners=True)

        p3 = torch.cat([c3, p4_up], 1)
        p3 = self.conv_set_1(p3)

        # head
        # s = 32, for detecting large objects
        p5 = self.extra_conv_3(p5)
        pred_3 = self.pred_3(p5)

        # s = 16, for detecting medium objects
        p4 = self.extra_conv_2(p4)
        pred_2 = self.pred_2(p4)

        # s = 8, for detecting small objects
        p3 = self.extra_conv_1(p3)
        pred_1 = self.pred_1(p3)

        preds = [pred_1, pred_2, pred_3]
        total_conf_pred = []
        total_cls_pred = []
        total_txtytwth_pred = []
        B = HW = 0
        for pred in preds:
            B_, abC_, H_, W_ = pred.size()

            # reshape pred to simplify the subsequent processing
            # [B, anchor_n * C, H, W] -> [B, H, W, anchor_n * C] -> [B, H*W, anchor_n*C]
            pred = pred.permute(0, 2, 3,
                                1).contiguous().view(B_, H_ * W_, abC_)

            # split pred into the objectness prediction, the class prediction and the bbox txtytwth prediction
            # [B, H*W*anchor_n, 1]
            conf_pred = pred[:, :, :1 * self.num_anchors].contiguous().view(
                B_, H_ * W_ * self.num_anchors, 1)
            # [B, H*W*anchor_n, num_cls]
            cls_pred = pred[:, :, 1 * self.num_anchors:(1 + self.num_classes) *
                            self.num_anchors].contiguous().view(
                                B_, H_ * W_ * self.num_anchors,
                                self.num_classes)
            # [B, H*W, anchor_n*4]
            txtytwth_pred = pred[:, :, (1 + self.num_classes) *
                                 self.num_anchors:].contiguous()

            total_conf_pred.append(conf_pred)
            total_cls_pred.append(cls_pred)
            total_txtytwth_pred.append(txtytwth_pred)
            B = B_
            HW += H_ * W_

        # concatenate all results along the H*W dimension
        conf_pred = torch.cat(total_conf_pred, dim=1)
        cls_pred = torch.cat(total_cls_pred, dim=1)
        txtytwth_pred = torch.cat(total_txtytwth_pred, dim=1)

        # train
        if self.trainable:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.num_anchors, 4)

            # decode the x1y1x2y2 coordinates from the txtytwth predictions
            x1y1x2y2_pred = (self.decode_boxes(txtytwth_pred) /
                             self.input_size).view(-1, 4)
            x1y1x2y2_gt = target[:, :, 7:].view(-1, 4)
            # compute the IoU between the pred boxes and the gt boxes
            iou_pred = tools.iou_score(x1y1x2y2_pred,
                                       x1y1x2y2_gt).view(B, -1, 1)

            # gt conf; this block ensures the IoU does not backpropagate gradients
            with torch.no_grad():
                gt_conf = iou_pred.clone()

            # use the IoU between the pred box and the gt box as the learning target for objectness.
            # [obj, cls, txtytwth, scale_weight, x1y1x2y2] -> [conf, obj, cls, txtytwth, scale_weight]
            target = torch.cat([gt_conf, target[:, :, :7]], dim=2)
            txtytwth_pred = txtytwth_pred.view(B, -1, 4)

            # compute the losses
            conf_loss, cls_loss, bbox_loss, iou_loss = tools.loss(
                pred_conf=conf_pred,
                pred_cls=cls_pred,
                pred_txtytwth=txtytwth_pred,
                pred_iou=iou_pred,
                label=target)

            return conf_loss, cls_loss, bbox_loss, iou_loss

        # test
        else:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.num_anchors, 4)
            with torch.no_grad():
                # batch size = 1
                # at test time the batch size is assumed to be 1,
                # so the batch dimension is not needed and is removed by indexing with [0].
                # [B, H*W*num_anchor, 1] -> [H*W*num_anchor, 1]
                conf_pred = torch.sigmoid(conf_pred)[0]
                # [B, H*W*num_anchor, 4] -> [H*W*num_anchor, 4]
                bboxes = torch.clamp(
                    (self.decode_boxes(txtytwth_pred) / self.input_size)[0],
                    0., 1.)
                # [B, H*W*num_anchor, C] -> [H*W*num_anchor, C],
                scores = torch.softmax(cls_pred[0, :, :], dim=1) * conf_pred

                # move the predictions to the CPU for post-processing
                scores = scores.to('cpu').numpy()
                bboxes = bboxes.to('cpu').numpy()

                # post-processing
                bboxes, scores, cls_inds = self.postprocess(bboxes, scores)

                return bboxes, scores, cls_inds
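
For reference, this is roughly how the test branch above is driven. The helper below is illustrative (the function name is not from the repository); note that the branch assumes a batch size of 1.

import torch

def run_inference(model, image):
    # image: [1, 3, H, W] float tensor; the test branch assumes batch size 1
    model.trainable = False
    model.eval()
    with torch.no_grad():
        bboxes, scores, cls_inds = model(image)
    # bboxes are normalized to [0, 1]; multiply by the original image size to draw them
    return bboxes, scores, cls_inds
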
Example #5
    def forward(self, x, target=None):
        # backbone
        _, c4, c5 = self.backbone(x)

        # head
        p5 = self.convsets_1(c5)

        # process the c4 feature
        p4 = self.reorg(self.route_layer(c4))

        # fuse the features
        p5 = torch.cat([p4, p5], dim=1)

        # head
        p5 = self.convsets_2(p5)

        # prediction
        prediction = self.pred(p5)

        B, abC, H, W = prediction.size()

        # [B, num_anchor * C, H, W] -> [B, H, W, num_anchor * C] -> [B, H*W, num_anchor*C]
        prediction = prediction.permute(0, 2, 3,
                                        1).contiguous().view(B, H * W, abC)

        # split pred into the objectness prediction, the class prediction and the bbox txtytwth prediction
        # [B, H*W*num_anchor, 1]
        conf_pred = prediction[:, :, :1 * self.num_anchors].contiguous().view(
            B, H * W * self.num_anchors, 1)
        # [B, H*W*num_anchor, num_cls]
        cls_pred = prediction[:, :,
                              1 * self.num_anchors:(1 + self.num_classes) *
                              self.num_anchors].contiguous().view(
                                  B, H * W * self.num_anchors,
                                  self.num_classes)
        # [B, H*W, num_anchor*4]
        txtytwth_pred = prediction[:, :, (1 + self.num_classes) *
                                   self.num_anchors:].contiguous()

        # train
        if self.trainable:
            txtytwth_pred = txtytwth_pred.view(B, H * W, self.num_anchors, 4)
            # decode bbox
            x1y1x2y2_pred = (self.decode_boxes(txtytwth_pred) /
                             self.input_size).view(-1, 4)
            x1y1x2y2_gt = target[:, :, 7:].view(-1, 4)

            # compute the IoU between the predicted boxes and the ground-truth boxes
            iou_pred = tools.iou_score(x1y1x2y2_pred,
                                       x1y1x2y2_gt).view(B, -1, 1)

            # use the IoU as the learning target for the confidence
            with torch.no_grad():
                gt_conf = iou_pred.clone()

            txtytwth_pred = txtytwth_pred.view(B, H * W * self.num_anchors, 4)
            # use the IoU as the learning target for the confidence
            # [obj, cls, txtytwth, x1y1x2y2] -> [conf, obj, cls, txtytwth]
            target = torch.cat([gt_conf, target[:, :, :7]], dim=2)

            # compute the losses
            conf_loss, cls_loss, bbox_loss, iou_loss = tools.loss(
                pred_conf=conf_pred,
                pred_cls=cls_pred,
                pred_txtytwth=txtytwth_pred,
                pred_iou=iou_pred,
                label=target)

            return conf_loss, cls_loss, bbox_loss, iou_loss

        # test
        else:
            txtytwth_pred = txtytwth_pred.view(B, H * W, self.num_anchors, 4)
            with torch.no_grad():
                # batch size = 1
                # at test time the batch size is assumed to be 1,
                # so the batch dimension is not needed and is removed by indexing with [0].
                # [B, H*W*num_anchor, 1] -> [H*W*num_anchor, 1]
                conf_pred = torch.sigmoid(conf_pred)[0]
                # [B, H*W*num_anchor, 4] -> [H*W*num_anchor, 4]
                bboxes = torch.clamp(
                    (self.decode_boxes(txtytwth_pred) / self.input_size)[0],
                    0., 1.)
                # [B, H*W*num_anchor, C] -> [H*W*num_anchor, C],
                scores = torch.softmax(cls_pred[0, :, :], dim=1) * conf_pred

                # move the predictions to the CPU for post-processing
                scores = scores.to('cpu').numpy()
                bboxes = bboxes.to('cpu').numpy()

                # post-processing
                bboxes, scores, cls_inds = self.postprocess(bboxes, scores)

                return bboxes, scores, cls_inds
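
self.postprocess receives numpy arrays of normalized boxes and per-class scores; a typical implementation applies a confidence threshold followed by class-wise NMS. The sketch below is a generic numpy NMS with an assumed threshold, not the repository's own routine.

import numpy as np

def nms_sketch(bboxes, scores, nms_thresh=0.5):
    # bboxes: [N, 4] in x1y1x2y2, scores: [N]; returns indices of the boxes to keep
    x1, y1, x2, y2 = bboxes[:, 0], bboxes[:, 1], bboxes[:, 2], bboxes[:, 3]
    areas = (x2 - x1) * (y2 - y1)
    order = scores.argsort()[::-1]  # highest score first
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(i)
        # intersection of the current box with all remaining boxes
        xx1 = np.maximum(x1[i], x1[order[1:]])
        yy1 = np.maximum(y1[i], y1[order[1:]])
        xx2 = np.minimum(x2[i], x2[order[1:]])
        yy2 = np.minimum(y2[i], y2[order[1:]])
        inter = np.maximum(0.0, xx2 - xx1) * np.maximum(0.0, yy2 - yy1)
        iou = inter / (areas[i] + areas[order[1:]] - inter + 1e-9)
        # drop boxes that overlap the kept box too strongly
        order = order[1:][iou < nms_thresh]
    return np.array(keep, dtype=np.int64)
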
Example #6
    def forward(self, x, target=None):
        # backbone
        c3, c4, c5 = self.backbone(x)

        p3 = self.conv1x1_0(c3)
        p4 = self.conv1x1_1(c4)
        p5 = self.conv1x1_2(c5)

        # FPN
        p4 = self.smooth_0(p4 + F.interpolate(p5, scale_factor=2.0))
        p3 = self.smooth_1(p3 + F.interpolate(p4, scale_factor=2.0))

        # PAN
        p4 = self.smooth_2(p4 + F.interpolate(p3, scale_factor=0.5))
        p5 = self.smooth_3(p5 + F.interpolate(p4, scale_factor=0.5))

        # det head
        pred_s = self.head_det_1(p3)
        pred_m = self.head_det_2(p4)
        pred_l = self.head_det_3(p5)

        preds = [pred_s, pred_m, pred_l]
        total_conf_pred = []
        total_cls_pred = []
        total_txtytwth_pred = []
        B = HW = 0
        for pred in preds:
            B_, abC_, H_, W_ = pred.size()

            # [B, anchor_n * C, H, W] -> [B, H, W, anchor_n * C] -> [B, H*W, anchor_n*C]
            pred = pred.permute(0, 2, 3,
                                1).contiguous().view(B_, H_ * W_, abC_)

            # split the prediction into conf_pred, cls_pred and txtytwth_pred
            # [B, H*W*anchor_n, 1]
            conf_pred = pred[:, :, :1 * self.num_anchors].contiguous().view(
                B_, H_ * W_ * self.num_anchors, 1)
            # [B, H*W*anchor_n, num_cls]
            cls_pred = pred[:, :, 1 * self.num_anchors:(1 + self.num_classes) *
                            self.num_anchors].contiguous().view(
                                B_, H_ * W_ * self.num_anchors,
                                self.num_classes)
            # [B, H*W, anchor_n*4]
            txtytwth_pred = pred[:, :, (1 + self.num_classes) *
                                 self.num_anchors:].contiguous()

            total_conf_pred.append(conf_pred)
            total_cls_pred.append(cls_pred)
            total_txtytwth_pred.append(txtytwth_pred)
            B = B_
            HW += H_ * W_

        conf_pred = torch.cat(total_conf_pred, 1)
        cls_pred = torch.cat(total_cls_pred, 1)
        txtytwth_pred = torch.cat(total_txtytwth_pred, 1)  #.view(B, -1, 4)

        # train
        if self.trainable:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.num_anchors, 4)
            # decode bbox
            x1y1x2y2_pred = (self.decode_boxes(txtytwth_pred) /
                             self.input_size).view(-1, 4)
            x1y1x2y2_gt = target[:, :, 7:].view(-1, 4)
            # compute iou
            iou_pred = tools.iou_score(x1y1x2y2_pred,
                                       x1y1x2y2_gt,
                                       batch_size=B)

            # gt conf
            with torch.no_grad():
                gt_conf = iou_pred.clone()

            # use the IoU between the pred bbox and the gt bbox as the conf label.
            # [obj, cls, txtytwth, x1y1x2y2] -> [conf, obj, cls, txtytwth]
            target = torch.cat([gt_conf, target[:, :, :7]], dim=2)

            txtytwth_pred = txtytwth_pred.view(B, -1, 4)
            conf_loss, cls_loss, bbox_loss, iou_loss = tools.loss(
                pred_conf=conf_pred,
                pred_cls=cls_pred,
                pred_txtytwth=txtytwth_pred,
                pred_iou=iou_pred,
                label=target)

            return conf_loss, cls_loss, bbox_loss, iou_loss

        # test
        else:
            txtytwth_pred = txtytwth_pred.view(B, HW, self.num_anchors, 4)
            with torch.no_grad():
                # batch size = 1
                all_obj = torch.sigmoid(conf_pred)[0]  # index 0 since the batch size is 1
                all_bbox = torch.clamp(
                    (self.decode_boxes(txtytwth_pred) / self.input_size)[0],
                    0., 1.)
                all_class = (torch.softmax(cls_pred[0, :, :], dim=1) * all_obj)
                # all_class = (torch.sigmoid(cls_pred[0, :, :]) * all_obj)
                # separate box pred and class conf
                all_class = all_class.to('cpu').numpy()
                all_bbox = all_bbox.to('cpu').numpy()

                bboxes, scores, cls_inds = self.postprocess(
                    all_bbox, all_class)

                # print(len(all_boxes))
                return bboxes, scores, cls_inds
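
For completeness, a sketch of how the trainable branch might be driven inside a training loop. The equal loss weighting and the target construction are assumptions; the repository builds the label tensor with its own tools before calling forward().

import torch

def train_step(model, optimizer, images, targets):
    # images: [B, 3, H, W]; targets: label tensor laid out as forward() expects,
    # with the x1y1x2y2 ground-truth boxes in the last four channels
    model.trainable = True
    model.train()
    conf_loss, cls_loss, bbox_loss, iou_loss = model(images, target=targets)
    total_loss = conf_loss + cls_loss + bbox_loss + iou_loss  # assumed equal weighting
    optimizer.zero_grad()
    total_loss.backward()
    optimizer.step()
    return float(total_loss.detach())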