Example #1
    def forward(self, xin, labels=None):
        """
        In this
        Args:
            xin (torch.Tensor): input feature map whose size is :math:`(N, C, H, W)`, \
                where N, C, H, W denote batchsize, channel width, height, width respectively.
            labels (torch.Tensor): label data whose size is :math:`(N, K, 5)`. \
                N and K denote batchsize and number of labels.
                Each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float) : center of bbox whose values range from 0 to 1.
                    w, h (float) : size of bbox whose values range from 0 to 1.
        Returns:
            loss (torch.Tensor): total loss - the target of backprop.
            loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \
                with boxsize-dependent weights.
            loss_wh (torch.Tensor): w, h loss - calculated by l2 without size averaging and \
                with boxsize-dependent weights.
            loss_obj (torch.Tensor): objectness loss - calculated by BCE.
            loss_cls (torch.Tensor): classification loss - calculated by BCE for each class.
            loss_l2 (torch.Tensor): total l2 loss - only for logging.
        """

        wh_pred = self.guide_wh(xin) #Anchor guiding

        if xin.type() == 'torch.cuda.HalfTensor': # DCN only supports FP32 for now, so cast the features and modules to float
            wh_pred = wh_pred.float()
            if labels is not None:
                labels = labels.float()
            self.Feature_adaption = self.Feature_adaption.float()
            self.conv = self.conv.float()
            xin = xin.float()

        feature_adapted = self.Feature_adaption(xin, wh_pred)

        output = self.conv(feature_adapted)
        wh_pred = torch.exp(wh_pred)
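        # exp() maps the raw guide outputs to positive scale factors, so wh_pred
        # now holds per-cell width/height multipliers for the anchors below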

        batchsize = output.shape[0]
        fsize = output.shape[2]
        image_size = fsize * self.stride
        n_ch = 5 + self.n_classes
        dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor

        wh_pred = wh_pred.view(batchsize, self.n_anchors, 2, fsize, fsize)
        wh_pred = wh_pred.permute(0, 1, 3, 4, 2).contiguous()

        output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize)
        output = output.permute(0, 1, 3, 4, 2).contiguous()

        x_shift = dtype(np.broadcast_to(
            np.arange(fsize, dtype=np.float32), output.shape[:4]))
        y_shift = dtype(np.broadcast_to(
            np.arange(fsize, dtype=np.float32).reshape(fsize, 1), output.shape[:4]))

        masked_anchors = np.array(self.masked_anchors)

        w_anchors = dtype(np.broadcast_to(np.reshape(
            masked_anchors[:, 0], (1, self.n_anchors-1, 1, 1)), [batchsize, self.n_anchors-1, fsize, fsize]))
        h_anchors = dtype(np.broadcast_to(np.reshape(
            masked_anchors[:, 1], (1, self.n_anchors-1, 1, 1)), [batchsize, self.n_anchors-1, fsize, fsize]))

        default_center = torch.zeros(batchsize, self.n_anchors, fsize, fsize, 2).type(dtype)

        pred_anchors = torch.cat((default_center, wh_pred), dim=-1).contiguous()

        anchors_based = pred_anchors[:, :self.n_anchors-1, :, :, :]   #anchor branch
        anchors_free = pred_anchors[:, self.n_anchors-1, :, :, :]     #anchor free branch
        anchors_based[...,2] *= w_anchors
        anchors_based[...,3] *= h_anchors
        anchors_free[...,2] *= self.stride*4
        anchors_free[...,3] *= self.stride*4
        pred_anchors[...,:2] = pred_anchors[...,:2].detach()

        if not self.training:

            pred = output.clone()
            pred[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(
                    pred[...,np.r_[:2, 4:n_ch]])
            pred[...,0] += x_shift
            pred[...,1] += y_shift
            pred[...,:2] *= self.stride
            pred[...,2] = torch.exp(pred[...,2])*(pred_anchors[...,2])
            pred[...,3] = torch.exp(pred[...,3])*(pred_anchors[...,3])
            refined_pred = pred.view(batchsize, -1, n_ch)
            return refined_pred.data

        #training for anchor prediction
        if self.training:

            target = torch.zeros(batchsize, self.n_anchors,
                                fsize, fsize, n_ch).type(dtype)
            l1_target = torch.zeros(batchsize, self.n_anchors,
                                fsize, fsize, 4).type(dtype)
            tgt_scale = torch.zeros(batchsize, self.n_anchors,
                                fsize, fsize, 4).type(dtype)
            obj_mask = torch.ones(batchsize, self.n_anchors, fsize, fsize).type(dtype)

            cls_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize, self.n_classes).type(dtype)
            coord_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize).type(dtype)
            anchor_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize).type(dtype)

            labels = labels.data
            mixup = labels.shape[2] > 5  # an extra column carries per-label mixup weights
            if mixup:
                label_cut = labels[...,:5]
            else:
                label_cut = labels
            nlabel = (label_cut.sum(dim=2) > 0).sum(dim=1)  # number of objects

            truth_x_all = labels[:, :, 1] * 1.
            truth_y_all = labels[:, :, 2] * 1.
            truth_w_all = labels[:, :, 3] * 1.
            truth_h_all = labels[:, :, 4] * 1.
            truth_i_all = (truth_x_all/image_size*fsize).to(torch.int16).cpu().numpy()
            truth_j_all = (truth_y_all/image_size*fsize).to(torch.int16).cpu().numpy()

            pred = output.clone()
            pred[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(
                    pred[...,np.r_[:2, 4:n_ch]])
            pred[...,0] += x_shift
            pred[...,1] += y_shift
            pred[...,2] = torch.exp(pred[...,2])*(pred_anchors[...,2])
            pred[...,3] = torch.exp(pred[...,3])*(pred_anchors[...,3])
            pred[...,:2] *= self.stride

            pred_boxes = pred[...,:4].data
            for b in range(batchsize):
                n = int(nlabel[b])
                if n == 0:
                    continue

                truth_box = dtype(np.zeros((n, 4)))
                truth_box[:n, 2] = truth_w_all[b, :n]
                truth_box[:n, 3] = truth_h_all[b, :n]
                truth_i = truth_i_all[b, :n]
                truth_j = truth_j_all[b, :n]

                # calculate iou between truth and reference anchors
                anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors, xyxy=False)
                best_n_all = np.argmax(anchor_ious_all, axis=1)
                best_anchor_iou = anchor_ious_all[np.arange(anchor_ious_all.shape[0]),best_n_all]
                best_n = best_n_all % 3
                best_n_mask = ((best_n_all == self.anch_mask[0]) | (
                    best_n_all == self.anch_mask[1]) | (best_n_all == self.anch_mask[2]))

                truth_box[:n, 0] = truth_x_all[b, :n]
                truth_box[:n, 1] = truth_y_all[b, :n]
                pred_box = pred_boxes[b]
                pred_ious = bboxes_iou(pred_box.view(-1,4),
                        truth_box, xyxy=False)
                pred_best_iou, _= pred_ious.max(dim=1)
                pred_best_iou = (pred_best_iou > self.ignore_thre)
                pred_best_iou = pred_best_iou.view(pred_box.shape[:3])
                obj_mask[b] = ~pred_best_iou  # ignore objectness where a prediction already matches some truth
                truth_box[:n, 0] = 0
                truth_box[:n, 1] = 0

                if sum(best_n_mask) == 0:
                    continue
                for ti in range(best_n.shape[0]):
                    if best_n_mask[ti] == 1:
                        i, j = truth_i[ti], truth_j[ti]
                        a = best_n[ti]
                        free_iou = bboxes_iou(truth_box[ti].cpu().view(-1,4),
                                pred_anchors[b, self.n_anchors-1, j, i, :4].data.cpu().view(-1,4),xyxy=False)  #iou of pred anchor 

                        #choose the best anchor
                        if free_iou > best_anchor_iou[ti]:
                            aa = self.n_anchors-1
                        else:
                            aa = a

                        cls_mask[b, aa, j, i, :] = 1
                        coord_mask[b, aa, j, i] = 1

                        anchor_mask[b, self.n_anchors-1, j, i] = 1
                        anchor_mask[b, a, j, i] = 1

                        obj_mask[b, aa, j, i] = 1 if not mixup else labels[b, ti, 5]

                        target[b, a, j, i, 0] = truth_x_all[b, ti]
                        target[b, a, j, i, 1] = truth_y_all[b, ti]
                        target[b, a, j, i, 2] = truth_w_all[b, ti]
                        target[b, a, j, i, 3] = truth_h_all[b, ti]

                        target[b, self.n_anchors-1, j, i, 0] = truth_x_all[b, ti]
                        target[b, self.n_anchors-1, j, i, 1] = truth_y_all[b, ti]
                        target[b, self.n_anchors-1, j, i, 2] = truth_w_all[b, ti]
                        target[b, self.n_anchors-1, j, i, 3] = truth_h_all[b, ti]

                        l1_target[b, aa, j, i, 0] = truth_x_all[b, ti]/image_size *fsize - i*1.0
                        l1_target[b, aa, j, i, 1] = truth_y_all[b, ti]/image_size *fsize - j*1.0
                        l1_target[b, aa, j, i, 2] = torch.log(truth_w_all[b, ti]/\
                            (pred_anchors[b, aa, j, i, 2])+ 1e-12)
                        l1_target[b, aa, j, i, 3] = torch.log(truth_h_all[b, ti]/\
                            (pred_anchors[b, aa, j, i, 3]) + 1e-12)
                        target[b, aa, j, i, 4] = 1
                        if self._label_smooth:
                            smooth_delta = 1
                            smooth_weight = 1. / self.n_classes
                            target[b, aa, j, i, 5:]= smooth_weight* smooth_delta

                            target[b, aa, j, i, 5 + labels[b, ti,
                                0].to(torch.int16)] = 1 - smooth_delta*smooth_weight
                        else:
                            target[b,aa, j, i, 5 + labels[b, ti,
                                0].to(torch.int16)] = 1

                        tgt_scale[b, aa,j, i, :] = 2.0 - truth_w_all[b, ti]*truth_h_all[b, ti] / image_size/image_size


            # Anchor loss
            anchorcoord_mask = anchor_mask>0
            loss_anchor = self.iou_wh_loss(pred_anchors[...,:4][anchorcoord_mask], target[...,:4][anchorcoord_mask]).sum()/batchsize

            #Prediction loss
            coord_mask = coord_mask>0
            loss_iou = (tgt_scale[coord_mask][...,0]*\
                    self.iou_loss(pred[..., :4][coord_mask], target[..., :4][coord_mask])).sum() / batchsize
            tgt_scale = tgt_scale[...,:2]
            loss_xy = (tgt_scale*self.bcewithlog_loss(output[...,:2], l1_target[...,:2])).sum() / batchsize
            loss_wh = (tgt_scale*self.l1_loss(output[...,2:4], l1_target[...,2:4])).sum() / batchsize
            loss_l1 = loss_xy + loss_wh
            loss_obj = (obj_mask*(self.bcewithlog_loss(output[..., 4], target[..., 4]))).sum() / batchsize
            loss_cls = (cls_mask*(self.bcewithlog_loss(output[..., 5:], target[..., 5:]))).sum()/ batchsize

            loss = loss_anchor + loss_iou + loss_l1 + loss_obj + loss_cls

            return loss, loss_anchor, loss_iou, loss_l1, loss_obj, loss_cls
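
A minimal input-preparation sketch for Example #1 (hypothetical: the class owning
this forward() is not shown above, so the instance `head` and the channel count are
assumed; only the call signature and the six returned losses come from the code):

    import torch

    xin = torch.randn(2, 256, 13, 13)   # (N, C, H, W) feature map
    labels = torch.zeros(2, 10, 5)      # (N, K, 5): [class, xc, yc, w, h]
    labels[0, 0] = torch.tensor([3., 208., 208., 64., 96.])  # one box, pixel coords
    # with `head` an already-constructed instance of the module above:
    # loss, loss_anchor, loss_iou, loss_l1, loss_obj, loss_cls = head(xin, labels)
    # loss.backward()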
Example #2
    def forward(self, xin, labels=None):
        """
        Compute the YOLO layer output and, at training time, the losses against
        ``labels``.
        Args:
            xin (torch.Tensor): input feature map whose size is :math:`(N, C, H, W)`, \
                where N, C, H, W denote batchsize, channel, height, width respectively.
            labels (torch.Tensor): label data whose size is :math:`(N, K, 5)`. \
                N and K denote batchsize and number of labels.
                Each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float) : center of bbox whose values range from 0 to 1.
                    w, h (float) : size of bbox whose values range from 0 to 1.
        Returns:
            loss (torch.Tensor): total loss - the target of backprop.
            loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE) \
                with boxsize-dependent weights.
            loss_wh (torch.Tensor): w, h loss - calculated by l2 without size averaging and \
                with boxsize-dependent weights.
            loss_obj (torch.Tensor): objectness loss - calculated by BCE.
            loss_cls (torch.Tensor): classification loss - calculated by BCE for each class.
            loss_l2 (torch.Tensor): total l2 loss - only for logging.
        """
        if not self.use_cfg:
            output = self.conv(xin)
        else:
            output = xin

        batchsize = output.shape[0]
        fsize = output.shape[2]
        n_ch = 5 + self.n_classes
        dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor

        output = output.view(batchsize, self.n_anchors, n_ch, fsize, fsize)
        output = output.permute(0, 1, 3, 4, 2)

        # logistic activation for xy, obj, cls
        output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(
            output[..., np.r_[:2, 4:n_ch]])

        # calculate pred - xywh obj cls

        x_shift = dtype(np.broadcast_to(
            np.arange(fsize, dtype=np.float32), output.shape[:4]))
        y_shift = dtype(np.broadcast_to(
            np.arange(fsize, dtype=np.float32).reshape(fsize, 1), output.shape[:4]))

        masked_anchors = np.array(self.masked_anchors)

        w_anchors = dtype(np.broadcast_to(np.reshape(
            masked_anchors[:, 0], (1, self.n_anchors, 1, 1)), output.shape[:4]))
        h_anchors = dtype(np.broadcast_to(np.reshape(
            masked_anchors[:, 1], (1, self.n_anchors, 1, 1)), output.shape[:4]))

        pred = output.clone()
        pred[..., 0] += x_shift
        pred[..., 1] += y_shift
        pred[..., 2] = torch.exp(pred[..., 2]) * w_anchors
        pred[..., 3] = torch.exp(pred[..., 3]) * h_anchors

        if labels is None:  # not training
            pred[..., :4] *= self.stride
            return pred.view(batchsize, -1, n_ch).data

        pred = pred[..., :4].data

        # target assignment

        tgt_mask = torch.zeros(batchsize, self.n_anchors,
                               fsize, fsize, 4 + self.n_classes).type(dtype)
        obj_mask = torch.ones(batchsize, self.n_anchors,
                              fsize, fsize).type(dtype)
        tgt_scale = torch.zeros(batchsize, self.n_anchors,
                                fsize, fsize, 2).type(dtype)

        target = torch.zeros(batchsize, self.n_anchors,
                             fsize, fsize, n_ch).type(dtype)

        labels = labels.cpu().data
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects

        truth_x_all = labels[:, :, 1] * fsize
        truth_y_all = labels[:, :, 2] * fsize
        truth_w_all = labels[:, :, 3] * fsize
        truth_h_all = labels[:, :, 4] * fsize
        truth_i_all = truth_x_all.to(torch.int16).numpy()
        truth_j_all = truth_y_all.to(torch.int16).numpy()

        for b in range(batchsize):
            n = int(nlabel[b])
            if n == 0:
                continue
            truth_box = dtype(np.zeros((n, 4)))
            truth_box[:n, 2] = truth_w_all[b, :n]
            truth_box[:n, 3] = truth_h_all[b, :n]
            truth_i = truth_i_all[b, :n]
            truth_j = truth_j_all[b, :n]

            # calculate iou between truth and reference anchors
            anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors)
            best_n_all = np.argmax(anchor_ious_all, axis=1)
            best_n = best_n_all % 3
            best_n_mask = ((best_n_all == self.anch_mask[0]) | (
                best_n_all == self.anch_mask[1]) | (best_n_all == self.anch_mask[2]))

            truth_box[:n, 0] = truth_x_all[b, :n]
            truth_box[:n, 1] = truth_y_all[b, :n]

            pred_ious = bboxes_iou(
                pred[b].view(-1, 4), truth_box, xyxy=False)
            pred_best_iou, _ = pred_ious.max(dim=1)
            pred_best_iou = (pred_best_iou > self.ignore_thre)
            pred_best_iou = pred_best_iou.view(pred[b].shape[:3])
            # set mask to zero (ignore) if pred matches truth
            obj_mask[b] = ~pred_best_iou

            if sum(best_n_mask) == 0:
                continue

            for ti in range(best_n.shape[0]):
                if best_n_mask[ti] == 1:
                    i, j = truth_i[ti], truth_j[ti]
                    a = best_n[ti]
                    obj_mask[b, a, j, i] = 1
                    tgt_mask[b, a, j, i, :] = 1
                    target[b, a, j, i, 0] = truth_x_all[b, ti] - \
                        truth_x_all[b, ti].to(torch.int16).to(torch.float)
                    target[b, a, j, i, 1] = truth_y_all[b, ti] - \
                        truth_y_all[b, ti].to(torch.int16).to(torch.float)
                    target[b, a, j, i, 2] = torch.log(
                        truth_w_all[b, ti] / torch.Tensor(self.masked_anchors)[best_n[ti], 0] + 1e-16)
                    target[b, a, j, i, 3] = torch.log(
                        truth_h_all[b, ti] / torch.Tensor(self.masked_anchors)[best_n[ti], 1] + 1e-16)
                    target[b, a, j, i, 4] = 1
                    target[b, a, j, i, 5 + labels[b, ti,
                                                  0].to(torch.int16).numpy()] = 1
                    tgt_scale[b, a, j, i, :] = torch.sqrt(
                        2 - truth_w_all[b, ti] * truth_h_all[b, ti] / fsize / fsize)

        # loss calculation

        output[..., 4] *= obj_mask
        output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
        output[..., 2:4] *= tgt_scale

        target[..., 4] *= obj_mask
        target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask
        target[..., 2:4] *= tgt_scale

        bceloss = nn.BCELoss(weight=tgt_scale*tgt_scale, reduction='sum')  # weighted BCE, summed to match the other loss terms
        loss_xy = bceloss(output[..., :2], target[..., :2])
        loss_wh = self.l2_loss(output[..., 2:4], target[..., 2:4]) / 2
        loss_obj = self.bce_loss(output[..., 4], target[..., 4])
        loss_cls = self.bce_loss(output[..., 5:], target[..., 5:])
        loss_l2 = self.l2_loss(output, target)

        loss = loss_xy + loss_wh + loss_obj + loss_cls

        return loss, loss_xy, loss_wh, loss_obj, loss_cls, loss_l2
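
The target encoding used in Example #2 can be checked in isolation. A minimal
sketch (the numbers are made up; `anchor_w` stands in for
masked_anchors[best_n][0], which is already expressed in grid units):

    import math

    fsize, anchor_w = 13, 3.6      # grid size; matched anchor width (grid units)
    truth_x, truth_w = 6.4, 4.2    # ground-truth centre x / width, in grid units
    tx = truth_x - int(truth_x)    # 0.4, the fractional part   -> target[..., 0]
    tw = math.log(truth_w / anchor_w + 1e-16)                 # -> target[..., 2]
    # decoding mirrors `pred` above: x = sigmoid(out_x) + x_shift, w = exp(out_w) * anchor_w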
Example #3
    def forward(self, xin, labels=None):
        """
        Compute the YOLO layer output and, at training time, the losses against
        ``labels``. When ``self.gaussian`` is set, xywh uncertainties are
        predicted as well, and the xy / wh losses become Gaussian negative
        log-likelihoods.
        Args:
            xin (torch.Tensor): input feature map whose size is :math:`(N, C, H, W)`, \
                where N, C, H, W denote batchsize, channel, height, width respectively.
            labels (torch.Tensor): label data whose size is :math:`(N, K, 5)`. \
                N and K denote batchsize and number of labels.
                Each label consists of [class, xc, yc, w, h]:
                    class (float): class index.
                    xc, yc (float): center of bbox whose values range from 0 to 1.
                    w, h (float): size of bbox whose values range from 0 to 1.
        Returns:
            loss (torch.Tensor): total loss - the target of backprop.
            loss_xy (torch.Tensor): x, y loss - calculated by binary cross entropy (BCE), \
                or by Gaussian negative log-likelihood when ``self.gaussian``, \
                with boxsize-dependent weights.
            loss_wh (torch.Tensor): w, h loss - calculated by l2, or by Gaussian \
                negative log-likelihood when ``self.gaussian``, with boxsize-dependent weights.
            loss_obj (torch.Tensor): objectness loss - calculated by BCE.
            loss_cls (torch.Tensor): classification loss - calculated by BCE for each class.
        """
        output = self.conv(xin)

        batchsize = output.shape[0]
        fsize = output.shape[2]
        n_ch = 5 + self.n_classes  # channels per anchor w/o xywh uncertainties
        dtype = torch.cuda.FloatTensor if xin.is_cuda else torch.FloatTensor

        output = output.view(batchsize, self.n_anchors, -1, fsize, fsize)
        output = output.permute(
            0, 1, 3, 4,
            2)  # shape: [batch, anchor, grid_y, grid_x, channels_per_anchor]

        if self.gaussian:
            # logistic activation for sigma of xywh
            sigma_xywh = output[
                ...,
                -4:]  # shape: [batch, anchor, grid_y, grid_x, 4(= xywh uncertainties)]
            sigma_xywh = torch.sigmoid(sigma_xywh)

            output = output[..., :-4]
        # output shape: [batch, anchor, grid_y, grid_x, n_class + 5(= x, y, w, h, objectness)]

        # logistic activation for xy, obj, cls
        output[..., np.r_[:2, 4:n_ch]] = torch.sigmoid(output[...,
                                                              np.r_[:2,
                                                                    4:n_ch]])

        # calculate pred - xywh obj cls

        x_shift = dtype(
            np.broadcast_to(np.arange(fsize, dtype=np.float32),
                            output.shape[:4]))
        y_shift = dtype(
            np.broadcast_to(
                np.arange(fsize, dtype=np.float32).reshape(fsize, 1),
                output.shape[:4]))

        masked_anchors = np.array(self.masked_anchors)

        w_anchors = dtype(
            np.broadcast_to(
                np.reshape(masked_anchors[:, 0], (1, self.n_anchors, 1, 1)),
                output.shape[:4]))
        h_anchors = dtype(
            np.broadcast_to(
                np.reshape(masked_anchors[:, 1], (1, self.n_anchors, 1, 1)),
                output.shape[:4]))

        pred = output.clone()
        pred[..., 0] += x_shift
        pred[..., 1] += y_shift
        pred[..., 2] = torch.exp(pred[..., 2]) * w_anchors
        pred[..., 3] = torch.exp(pred[..., 3]) * h_anchors

        if labels is None:  # not training
            pred[..., :4] *= self.stride
            pred = pred.view(
                batchsize, -1,
                n_ch)  # shape: [batch, anchor x grid_y x grid_x, n_class + 5]

            if self.gaussian:
                # scale objectness confidence with xywh uncertainties
                sigma_xywh = sigma_xywh.view(
                    batchsize, -1,
                    4)  # shape: [batch, anchor x grid_y x grid_x, 4]
                sigma = sigma_xywh.mean(dim=-1)
                pred[..., 4] *= (1.0 - sigma)

                # unnormalize uncertainties
                sigma_xywh = torch.sqrt(sigma_xywh)
                sigma_xywh[..., :2] *= self.stride
                sigma_xywh[..., 2:] = torch.exp(sigma_xywh[..., 2:])

                # concat pred with uncertainties
                pred = torch.cat([
                    pred, sigma_xywh
                ], 2)  # shape: [batch, anchor x grid_y x grid_x, n_class + 9]

            return pred.data

        pred = pred[
            ..., :
            4].data  # shape: [batch, anchor, grid_y, grid_x, 4(= x, y, w, h)]

        # target assignment

        tgt_mask = torch.zeros(batchsize, self.n_anchors, fsize, fsize,
                               4 + self.n_classes).type(dtype)
        obj_mask = torch.ones(batchsize, self.n_anchors, fsize,
                              fsize).type(dtype)
        tgt_scale = torch.zeros(batchsize, self.n_anchors, fsize, fsize,
                                2).type(dtype)

        target = torch.zeros(batchsize, self.n_anchors, fsize, fsize,
                             n_ch).type(dtype)

        labels = labels.cpu().data
        nlabel = (labels.sum(dim=2) > 0).sum(dim=1)  # number of objects

        truth_x_all = labels[:, :, 1] * fsize
        truth_y_all = labels[:, :, 2] * fsize
        truth_w_all = labels[:, :, 3] * fsize
        truth_h_all = labels[:, :, 4] * fsize
        truth_i_all = truth_x_all.to(torch.int16).numpy()
        truth_j_all = truth_y_all.to(torch.int16).numpy()

        for b in range(batchsize):
            n = int(nlabel[b])
            if n == 0:
                continue
            truth_box = dtype(np.zeros((n, 4)))
            truth_box[:n, 2] = truth_w_all[b, :n]
            truth_box[:n, 3] = truth_h_all[b, :n]
            truth_i = truth_i_all[b, :n]
            truth_j = truth_j_all[b, :n]

            # calculate iou between truth and reference anchors
            anchor_ious_all = bboxes_iou(truth_box.cpu(), self.ref_anchors)
            best_n_all = np.argmax(anchor_ious_all, axis=1)
            best_n = best_n_all % 3
            best_n_mask = ((best_n_all == self.anch_mask[0]) |
                           (best_n_all == self.anch_mask[1]) |
                           (best_n_all == self.anch_mask[2]))

            truth_box[:n, 0] = truth_x_all[b, :n]
            truth_box[:n, 1] = truth_y_all[b, :n]

            pred_ious = bboxes_iou(pred[b].view(-1, 4), truth_box, xyxy=False)
            pred_best_iou, _ = pred_ious.max(dim=1)
            pred_best_iou = (pred_best_iou > self.ignore_thre)
            pred_best_iou = pred_best_iou.view(pred[b].shape[:3])
            # set mask to zero (ignore) if pred matches truth
            obj_mask[b] = ~pred_best_iou

            if sum(best_n_mask) == 0:
                continue

            for ti in range(best_n.shape[0]):
                if best_n_mask[ti] == 1:
                    i, j = truth_i[ti], truth_j[ti]
                    a = best_n[ti]
                    obj_mask[b, a, j, i] = 1
                    tgt_mask[b, a, j, i, :] = 1
                    target[b, a, j, i, 0] = truth_x_all[b, ti] - \
                        truth_x_all[b, ti].to(torch.int16).to(torch.float)
                    target[b, a, j, i, 1] = truth_y_all[b, ti] - \
                        truth_y_all[b, ti].to(torch.int16).to(torch.float)
                    target[b, a, j, i, 2] = torch.log(
                        truth_w_all[b, ti] /
                        torch.Tensor(self.masked_anchors)[best_n[ti], 0] +
                        1e-16)
                    target[b, a, j, i, 3] = torch.log(
                        truth_h_all[b, ti] /
                        torch.Tensor(self.masked_anchors)[best_n[ti], 1] +
                        1e-16)
                    target[b, a, j, i, 4] = 1
                    target[b, a, j, i,
                           5 + labels[b, ti, 0].to(torch.int16).numpy()] = 1
                    tgt_scale[b, a, j, i, :] = 2 - truth_w_all[
                        b, ti] * truth_h_all[b, ti] / fsize / fsize

        # loss calculation
        output[..., 4] *= obj_mask
        output[..., np.r_[0:4, 5:n_ch]] *= tgt_mask

        target[..., 4] *= obj_mask
        target[..., np.r_[0:4, 5:n_ch]] *= tgt_mask

        loss_obj = F.binary_cross_entropy(output[..., 4],
                                          target[..., 4],
                                          reduction='sum')
        loss_cls = F.binary_cross_entropy(output[..., 5:],
                                          target[..., 5:],
                                          reduction='sum')

        if self.gaussian:
            loss_xy = -torch.log(
                self._gaussian_dist_pdf(output[..., :2], target[..., :2],
                                        sigma_xywh[..., :2]) + 1e-9) / 2.0
            loss_wh = -torch.log(
                self._gaussian_dist_pdf(output[..., 2:4], target[..., 2:4],
                                        sigma_xywh[..., 2:4]) + 1e-9) / 2.0
        else:
            loss_xy = F.binary_cross_entropy(output[..., :2],
                                             target[..., :2],
                                             reduction='none')
            loss_wh = F.mse_loss(
                output[..., 2:4], target[..., 2:4], reduction='none') / 2.0
        loss_xy = (loss_xy * tgt_scale).sum()
        loss_wh = (loss_wh * tgt_scale).sum()

        loss = loss_xy + loss_wh + loss_obj + loss_cls

        return loss, loss_xy, loss_wh, loss_obj, loss_cls
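
Example #3 calls self._gaussian_dist_pdf, which is not shown above. A sketch of a
definition consistent with its use here (the third argument treated as the variance
of a Gaussian evaluated at the prediction) is:

    import math
    import torch

    def _gaussian_dist_pdf(self, val, mean, var):
        # N(val; mean, var); the loss above takes -log(pdf + 1e-9) / 2
        return torch.exp(-(val - mean) ** 2.0 / var / 2.0) / torch.sqrt(2.0 * math.pi * var)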