# Example no. 1
    def forward(self, imgs, locs, label_class=None, label_box=None):
        '''
        Run the detector on a batch of view images.

        Param:
        imgs:        F(b, 3, vsz, vsz)
        locs:        F(b, 4)
        label_class: L(b, N_max) or None
        label_box:   F(b, N_max, 4) or None

        Return 1 (training, labels given):
        loss:        F(b)

        Return 2 (inference, labels omitted):
        pred_cls_i:  L(b, topk)
        pred_cls_p:  F(b, topk)
        pred_reg:    F(b, topk, 4)
        '''

        # ---- FPN forward: lateral projections, top-down merge, extra levels ----
        C3, C4, C5 = self.backbone(imgs)
        P5 = self.prj_5(C5)
        P4 = self.prj_4(C4) + self.upsample(P5)
        P3 = self.prj_3(C3) + self.upsample(P4)
        P3, P4, P5 = self.conv_3(P3), self.conv_4(P4), self.conv_5(P5)
        P6 = self.conv_out6(C5)
        P7 = self.conv_out7(self.relu(P6))
        features = [P3, P4, P5, P6, P7]
        assert len(features) == len(self.r)

        # ---- shared heads: classification + decoded regression per level ----
        cls_levels = []
        reg_levels = []
        for level, feat in enumerate(features):
            cls_map = self.conv_cls(feat).permute(0, 2, 3, 1).contiguous()
            reg_map = self.conv_reg(feat) * self.scale_param[level]
            reg_map = reg_map.permute(0, 2, 3, 1).contiguous()  # b, ph, pw, 4
            ph, pw = self.phpw[level][0], self.phpw[level][1]
            reg_map = self.decode_box(reg_map, self.view_size, self.view_size, ph, pw)
            cls_levels.append(cls_map.view(cls_map.shape[0], -1, self.classes))
            reg_levels.append(reg_map.view(reg_map.shape[0], -1, 4))
        pred_cls = torch.cat(cls_levels, dim=1)  # F(b, n, classes)
        pred_reg = torch.cat(reg_levels, dim=1)  # F(b, n, 4)

        if label_class is None or label_box is None:
            # inference path
            return self._decode(pred_cls, pred_reg, locs)

        # Cap the number of ground-truth objects per image at 200.
        if min(label_class.shape[1], 200) == 200:
            label_class = label_class[:, :200]
            label_box = label_box[:, :200, :]

        # ---- build per-level targets and flatten like the predictions ----
        tgt_cls_levels = []
        tgt_reg_levels = []
        for level in range(len(self.r)):
            tgt_cls_l, tgt_reg_l = assign_box(label_class, label_box, locs,
                self.view_size, self.view_size,
                self.phpw[level][0], self.phpw[level][1],
                self.tlbr_max_minmax[level][0], self.tlbr_max_minmax[level][1],
                self.r[level])
            tgt_cls_levels.append(tgt_cls_l.view(tgt_cls_l.shape[0], -1))
            tgt_reg_levels.append(tgt_reg_l.view(tgt_reg_l.shape[0], -1, 4))
        target_cls = torch.cat(tgt_cls_levels, dim=1)  # L(b, n)
        target_reg = torch.cat(tgt_reg_levels, dim=1)  # F(b, n, 4)

        # ---- per-image loss: focal loss on pos+neg, IoU loss on pos only ----
        m_negpos = target_cls > -1  # B(b, n): pos + neg (ignored locations == -1)
        m_pos = target_cls > 0      # B(b, n): positives only
        num_pos = torch.sum(m_pos, dim=1).clamp_(min=1)  # L(b), avoid div by 0
        losses = []
        for b in range(locs.shape[0]):
            cls_b = pred_cls[b][m_negpos[b]]         # F(S+-, classes)
            tgt_cls_b = target_cls[b][m_negpos[b]]   # L(S+-)
            reg_b = pred_reg[b][m_pos[b]]            # F(S+, 4)
            tgt_reg_b = target_reg[b][m_pos[b]]      # F(S+, 4)
            loss_cls_b = sigmoid_focal_loss(cls_b, tgt_cls_b, 2.0, 0.25).sum().view(1)
            loss_reg_b = iou_loss(reg_b, tgt_reg_b).sum().view(1)
            losses.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
        return torch.cat(losses, dim=0)  # F(b)
# Example no. 2
    def forward(self, x, loc, label_class=None, label_box=None):
        '''
        Param:
        x:           FloatTensor(batch_num, 3, H, W)
        loc:         FloatTensor(batch_num, 4)
        label_class: LongTensor (batch_num, N_max) or None
        label_box:   FloatTensor(batch_num, N_max, 4) or None

        Return 1 (training, when labels are given):
        loss: FloatTensor(batch_num)

        Return 2 (inference, when labels are None):
        cls_i_preds: LongTensor (batch_num, topk)
        cls_p_preds: FloatTensor(batch_num, topk)
        reg_preds:   FloatTensor(batch_num, topk, 4)
        '''

        # FPN: project C3..C5 laterally, merge top-down, add extra P6/P7 levels.
        C3, C4, C5 = self.backbone(x)
        P5 = self.prj_5(C5)
        P4 = self.prj_4(C4)
        P3 = self.prj_3(C3)
        P4 = P4 + self.upsample(P5)
        P3 = P3 + self.upsample(P4)
        P3 = self.conv_3(P3)
        P4 = self.conv_4(P4)
        P5 = self.conv_5(P5)
        P6 = self.conv_out6(C5)
        P7 = self.conv_out7(self.relu(P6))
        pred_list = [P3, P4, P5, P6, P7]
        assert len(pred_list) == len(self.regions) - 1
        # Shared classification / regression heads over every pyramid level.
        cls_out = []
        reg_out = []
        for i, item in enumerate(pred_list):
            cls_i = self.conv_cls(item)
            # exp() keeps the predicted distances positive; scale_div[i] is a
            # per-level scale factor.
            reg_i = (self.conv_reg(item) * self.scale_div[i]).exp()
            cls_i = cls_i.permute(0, 2, 3, 1).contiguous()
            reg_i = reg_i.permute(0, 2, 3, 1).contiguous()
            cls_i = cls_i.view(cls_i.shape[0], -1, self.classes)
            reg_i = reg_i.view(reg_i.shape[0], -1, 4)
            # cls_i: [b, classes, H, W] -> [b, H*W, classes]
            # reg_i: [b, 4, H, W] -> [b, H*W, 4]
            cls_out.append(cls_i)
            reg_out.append(reg_i)
        cls_out = torch.cat(cls_out, dim=1)  # (b, shw, classes)
        reg_out = torch.cat(reg_out, dim=1)  # (b, shw, 4)
        if (label_class is not None) and (label_box is not None):
            targets_cls, targets_reg = self._encode(label_class, label_box,
                                                    loc)
            mask_cls = targets_cls > -1  # (b, shw): pos + neg (ignored == -1)
            mask_reg = targets_cls > 0   # (b, shw): positives only
            num_pos = torch.sum(mask_reg, dim=1).clamp_(min=1)  # (b), min 1 avoids div by 0
            loss = []
            for b in range(targets_cls.shape[0]):
                # BUGFIX: decode into fresh locals instead of assigning back
                # into reg_out[b] / targets_reg[b]. The original in-place index
                # assignment mutated reg_out — a tensor produced by the conv
                # heads that autograd may still need for the backward pass.
                # The decoded values are only used within this iteration, so
                # locals preserve the exact outputs without the mutation.
                reg_dec_b = distance2bbox(self.a_center_yx, reg_out[b])
                targets_dec_b = distance2bbox(self.a_center_yx, targets_reg[b])
                cls_out_b = cls_out[b][mask_cls[b]]          # (S+-, classes)
                reg_out_b = reg_dec_b[mask_reg[b]]           # (S+, 4)
                targets_cls_b = targets_cls[b][mask_cls[b]]  # (S+-)
                targets_reg_b = targets_dec_b[mask_reg[b]]   # (S+, 4)
                loss_cls_b = sigmoid_focal_loss(cls_out_b, targets_cls_b, 2.0,
                                                0.25).sum().view(1)
                loss_reg_b = iou_loss(reg_out_b, targets_reg_b).sum().view(1)
                loss.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
            return torch.cat(loss, dim=0)  # (b)
        else:
            return self._decode(cls_out, reg_out, loc)