Example #1
0
    def forward(self, x, loc, label_class=None, label_box=None):
        '''
        Run the detector; train mode when both labels are given, else decode.

        Param:
        x:           FloatTensor(batch_num, 3, H, W)
        loc:         FloatTensor(batch_num, 4)
        label_class: LongTensor(batch_num, N_max) or None
        label_box:   FloatTensor(batch_num, N_max, 4) or None

        Return 1 (labels given):
        loss: FloatTensor(batch_num)

        Return 2 (labels None):
        cls_i_preds: LongTensor(batch_num, topk)
        cls_p_preds: FloatTensor(batch_num, topk)
        reg_preds:   FloatTensor(batch_num, topk, 4)
        '''
        # Backbone features and FPN top-down pathway.
        C3, C4, C5 = self.backbone(x)
        P5 = self.prj_5(C5)
        P4 = self.prj_4(C4) + self.upsample(P5)
        P3 = self.prj_3(C3) + self.upsample(P4)
        P3 = self.conv_3(P3)
        P4 = self.conv_4(P4)
        P5 = self.conv_5(P5)
        # Extra pyramid levels from C5.
        P6 = self.conv_out6(C5)
        P7 = self.conv_out7(self.relu(P6))

        if self.balanced_fpn:
            # Fuse P3/P4/P5 at the P4 resolution, then redistribute the
            # fused map back out to the P3 and P5 scales.
            # max_pool2d positional args: kernel=3, stride=2, padding=1,
            # dilation=1, ceil_mode=False, return_indices=False
            P3 = F.max_pool2d(P3, 3, 2, 1, 1, False, False)
            P5 = self.upsample(P5)
            P4 = (P3 + P4 + P5) / 3.0
            # avg_pool2d positional args: kernel=3, stride=2, padding=1,
            # ceil_mode=False, count_include_pad=False
            P5 = F.avg_pool2d(P4, 3, 2, 1, False, False)
            P3 = self.upsample(P4)

        pred_list = [P3, P4, P5, P6, P7]
        assert len(pred_list) == self.scales

        # Shared heads over every level, flattened to (b, H*W*an, ...)
        cls_out, reg_out = [], []
        for feat in pred_list:
            # [b, an*classes, H, W] -> [b, H*W*an, classes]
            c = self.conv_cls(feat).permute(0, 2, 3, 1).contiguous()
            cls_out.append(c.view(c.shape[0], -1, self.classes))
            # [b, an*4, H, W] -> [b, H*W*an, 4]
            r = self.conv_reg(feat).permute(0, 2, 3, 1).contiguous()
            reg_out.append(r.view(r.shape[0], -1, 4))
        cls_out = torch.cat(cls_out, dim=1)  # (b, hwan, classes)
        reg_out = torch.cat(reg_out, dim=1)  # (b, hwan, 4)

        if (label_class is None) or (label_box is None):
            return self._decode(cls_out, reg_out, loc)

        targets_cls, targets_reg = self._encode(
            label_class, label_box, loc)  # (b, hwan), (b, hwan, 4)
        mask_cls = targets_cls > -1  # negatives + positives (b, hwan)
        mask_reg = targets_cls > 0   # positives only        (b, hwan)
        num_pos = torch.sum(mask_reg, dim=1).clamp_(min=self.scales)  # (b)
        loss = []
        for b in range(targets_cls.shape[0]):
            loss_cls_b = sigmoid_focal_loss(
                cls_out[b][mask_cls[b]],       # (S+-, classes)
                targets_cls[b][mask_cls[b]],   # (S+-)
                2.0, 0.25).sum().view(1)
            loss_reg_b = smooth_l1_loss(
                reg_out[b][mask_reg[b]],       # (S+, 4)
                targets_reg[b][mask_reg[b]],   # (S+, 4)
                0.11).sum().view(1)
            loss.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
        return torch.cat(loss, dim=0)  # (b)
Example #2
0
    def forward(self, imgs, locs, label_class=None, label_box=None):
        '''
        Run the detector; train mode when both labels are given, else decode.

        Param:
        imgs:        F(b, 3, vsz, vsz)
        locs:        F(b, 4)
        label_class: L(b, N_max) or None
        label_box:   F(b, N_max, 4) or None

        Return 1 (labels given):
        loss:        F(b)

        Return 2 (labels None):
        pred_cls_i:  L(b, topk)
        pred_cls_p:  F(b, topk)
        pred_reg:    F(b, topk, 4)
        '''

        # forward fpn: top-down pathway with lateral projections
        C3, C4, C5 = self.backbone(imgs)
        P5 = self.prj_5(C5)
        P4 = self.prj_4(C4) + self.upsample(P5)
        P3 = self.prj_3(C3) + self.upsample(P4)
        P3 = self.conv_3(P3)
        P4 = self.conv_4(P4)
        P5 = self.conv_5(P5)
        P6 = self.conv_out6(C5)
        P7 = self.conv_out7(self.relu(P6))
        pred_list = [P3, P4, P5, P6, P7]
        assert len(pred_list) == len(self.r)

        # get pred: shared heads per level; regression output is scaled by a
        # per-level learnable parameter and decoded to boxes in view coords
        pred_cls = []
        pred_reg = []
        for i, feature in enumerate(pred_list):
            cls_i = self.conv_cls(feature)
            reg_i = self.conv_reg(feature) * self.scale_param[i]
            cls_i = cls_i.permute(0, 2, 3, 1).contiguous()  # b, ph, pw, classes
            reg_i = reg_i.permute(0, 2, 3, 1).contiguous()  # b, ph, pw, 4
            reg_i = self.decode_box(reg_i, self.view_size, self.view_size,
                                    self.phpw[i][0], self.phpw[i][1])
            pred_cls.append(cls_i.view(cls_i.shape[0], -1, self.classes))
            pred_reg.append(reg_i.view(reg_i.shape[0], -1, 4))
        pred_cls = torch.cat(pred_cls, dim=1)
        pred_reg = torch.cat(pred_reg, dim=1)
        # pred_cls: F(b, n, classes)
        # pred_reg: F(b, n, 4)

        if (label_class is not None) and (label_box is not None):
            # Cap the number of ground-truth objects at 200.  Slicing is a
            # no-op when N_max <= 200, so no length check is needed (the old
            # `n_max` variable and its guard were dead code).
            label_class = label_class[:, :200]
            label_box = label_box[:, :200, :]
            # get target: assign labels/boxes to every location of each level
            target_cls = []
            target_reg = []
            for i in range(len(self.r)):
                target_cls_i, target_reg_i = assign_box(label_class, label_box, locs,
                    self.view_size, self.view_size, self.phpw[i][0], self.phpw[i][1],
                    self.tlbr_max_minmax[i][0], self.tlbr_max_minmax[i][1], self.r[i])
                target_cls.append(target_cls_i.view(target_cls_i.shape[0], -1))
                target_reg.append(target_reg_i.view(target_reg_i.shape[0], -1, 4))
            target_cls = torch.cat(target_cls, dim=1) # L(b, n)
            target_reg = torch.cat(target_reg, dim=1) # F(b, n, 4)
            # get loss: focal loss over pos+neg locations, IoU loss over
            # positives, normalized per image by the positive count (>= 1)
            m_negpos = target_cls > -1 # B(b, n)
            m_pos    = target_cls > 0  # B(b, n)
            num_pos = torch.sum(m_pos, dim=1).clamp_(min=1) # L(b)
            loss = []
            for b in range(locs.shape[0]):
                pred_cls_b = pred_cls[b][m_negpos[b]]     # F(S+-, classes)
                target_cls_b = target_cls[b][m_negpos[b]] # L(S+-)
                pred_reg_b = pred_reg[b][m_pos[b]]        # F(S+, 4)
                target_reg_b = target_reg[b][m_pos[b]]    # F(S+, 4)
                loss_cls_b = sigmoid_focal_loss(pred_cls_b, target_cls_b, 2.0, 0.25).sum().view(1)
                loss_reg_b = iou_loss(pred_reg_b, target_reg_b).sum().view(1)
                loss.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
            return torch.cat(loss, dim=0) # F(b)
        else:
            return self._decode(pred_cls, pred_reg, locs)
Example #3
0
    def forward(self, x, loc, label_class=None, label_box=None):
        '''
        BiFPN-style forward pass; train mode when both labels are given,
        otherwise decode predictions.

        Param:
        x:           FloatTensor(batch_num, 3, H, W)
        loc:         FloatTensor(batch_num, 4)
        label_class: LongTensor(batch_num, N_max) or None
        label_box:   FloatTensor(batch_num, N_max, 4) or None

        Return 1 (labels given):
        loss: FloatTensor(batch_num)

        Return 2 (labels None):
        cls_i_preds: LongTensor(batch_num, topk)
        cls_p_preds: FloatTensor(batch_num, topk)
        reg_preds:   FloatTensor(batch_num, topk, 4)
        '''
        # Fast normalized fusion weights: relu keeps them non-negative, then
        # each column is divided by its sum (+eps) so the weights feeding one
        # fusion node sum to ~1.  The indexing below implies w_2_in has 2 rows
        # (two-input nodes) and w_3_in has 3 rows (three-input nodes).
        w_2_in = self.w_relu(self.w_2_in)
        w_3_in = self.w_relu(self.w_3_in)
        w_2_in /= torch.sum(w_2_in, dim=0) + self.eps
        w_3_in /= torch.sum(w_3_in, dim=0) + self.eps

        C3, C4, C5 = self.backbone(x)

        # Lateral 1x1 projections (+BN) of the backbone features.
        D5 = self.bi1_prj1_5(C5)
        D5 = self.bi1_prj1_5bn(D5)
        D4 = self.bi1_prj1_4(C4)
        D4 = self.bi1_prj1_4bn(D4)
        D3 = self.bi1_prj1_3(C3)
        D3 = self.bi1_prj1_3bn(D3)

        # First BiFPN pass: top-down (E4, F3) then bottom-up (F4, F5).
        E4 = self.bi1_prj2_4(
            (w_2_in[0][0] * D4 + w_2_in[1][0] * self.upsample(D5)))

        F3 = self.bi1_prj3_3(
            (w_2_in[0][1] * D3 + w_2_in[1][1] * self.upsample(E4)))
        F4 = self.bi1_prj3_4((w_3_in[0][0] * self.downsample(F3) +
                              w_3_in[1][0] * E4 + w_3_in[2][0] * D4))
        F5 = self.bi1_prj3_5(
            (w_2_in[0][2] * D5 + w_2_in[1][2] * self.downsample(F4)))

        # Second BiFPN pass over F3/F4/F5, same node layout as the first.
        G4 = self.bi2_prj2_4(
            (w_2_in[0][3] * F4 + w_2_in[1][3] * self.upsample(F5)))

        H3 = self.bi2_prj3_3(
            (w_2_in[0][4] * F3 + w_2_in[1][4] * self.upsample(G4)))
        H4 = self.bi2_prj3_4((w_3_in[0][1] * self.downsample(H3) +
                              w_3_in[1][1] * G4 + w_3_in[2][1] * F4))
        H5 = self.bi2_prj3_5(
            (w_2_in[0][5] * F5 + w_2_in[1][5] * self.downsample(H4)))

        # Extra pyramid levels taken directly from C5.
        P6 = self.conv_out6(C5)
        P7 = self.conv_out7(self.relu(P6))

        P3, P4, P5, P6, P7 = H3, H4, H5, P6, P7

        # Debug/analysis hooks for logging the fusion weights (kept disabled).
        # log = [w_2_in.data.cpu().numpy(), w_3_in.data.cpu().numpy()]
        # np.save('./bifpn_weight/weight_log', log)
        # print('w_2_in = {0}, w_3_in={1}'.format(w_2_in.data.cpu().numpy(), w_3_in.data.cpu().numpy()))

        if self.balanced_fpn:
            # Fuse P3/P4/P5 at the P4 resolution, then redistribute the fused
            # map back out to the P3 and P5 scales.
            # max_pool2d args: kernel_size=3, stride=2, padding=1, dilation=1,
            # ceil_mode=False, return_indices=False
            P3 = F.max_pool2d(P3, 3, 2, 1, 1, False, False)
            P5 = self.upsample(P5)
            P4 = (P3 + P4 + P5) / 3.0
            # avg_pool2d args: kernel_size=3, stride=2, padding=1,
            # ceil_mode=False, count_include_pad=False
            P5 = F.avg_pool2d(P4, 3, 2, 1, False, False)
            P3 = self.upsample(P4)

        pred_list = [P3, P4, P5, P6, P7]
        assert len(pred_list) == self.scales
        # assert len(pred_list) == len(self.scales)

        # Shared classification/regression heads over every pyramid level.
        cls_out = []
        reg_out = []
        for item in pred_list:
            cls_i = self.conv_cls(item)
            reg_i = self.conv_reg(item)
            # cls_i: [b, an*classes, H, W] -> [b, H*W*an, classes]
            cls_i = cls_i.permute(0, 2, 3, 1).contiguous()
            cls_i = cls_i.view(cls_i.shape[0], -1, self.classes)
            # reg_i: [b, an*4, H, W] -> [b, H*W*an, 4]
            reg_i = reg_i.permute(0, 2, 3, 1).contiguous()
            reg_i = reg_i.view(reg_i.shape[0], -1, 4)
            cls_out.append(cls_i)
            reg_out.append(reg_i)

        # cls_out[b, hwan, classes]
        # reg_out[b, hwan, 4]
        cls_out = torch.cat(cls_out, dim=1)
        reg_out = torch.cat(reg_out, dim=1)

        if (label_class is not None) and (label_box is not None):
            targets_cls, targets_reg = self._encode(
                label_class, label_box, loc)  # (b, hwan), (b, hwan, 4)
            # Focal loss over negatives+positives; regression over positives.
            mask_cls = targets_cls > -1  # (b, hwan)
            mask_reg = targets_cls > 0  # (b, hwan)
            num_pos = torch.sum(mask_reg, dim=1).clamp_(min=self.scales)  # (b)
            # num_pos = torch.sum(mask_reg, dim=1).clamp_(min=len(self.scales)) # (b)
            loss = []
            for b in range(targets_cls.shape[0]):
                cls_out_b = cls_out[b][mask_cls[b]]  # (S+-, classes)
                reg_out_b = reg_out[b][mask_reg[b]]  # (S+, 4)
                targets_cls_b = targets_cls[b][mask_cls[b]]  # (S+-)
                targets_reg_b = targets_reg[b][mask_reg[b]]  # (S+, 4)
                loss_cls_b = sigmoid_focal_loss(cls_out_b, targets_cls_b, 2.0,
                                                0.25).sum().view(1)
                loss_reg_b = smooth_l1_loss(reg_out_b, targets_reg_b,
                                            0.11).sum().view(1)
                loss.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
            return torch.cat(loss, dim=0)  # (b)
        else:
            return self._decode(cls_out, reg_out, loc)
Example #4
0
    def forward(self, x, loc, label_class=None, label_box=None):
        '''
        Run the detector; train mode when both labels are given, else decode.

        Param:
        x:           FloatTensor(batch_num, 3, H, W)
        loc:         FloatTensor(batch_num, 4)
        label_class: LongTensor (batch_num, N_max) or None
        label_box:   FloatTensor(batch_num, N_max, 4) or None

        Return 1 (labels given):
        loss: FloatTensor(batch_num)

        Return 2 (labels None):
        cls_i_preds: LongTensor (batch_num, topk)
        cls_p_preds: FloatTensor(batch_num, topk)
        reg_preds:   FloatTensor(batch_num, topk, 4)
        '''
        # Backbone features and FPN top-down pathway.
        C3, C4, C5 = self.backbone(x)
        P5 = self.prj_5(C5)
        P4 = self.prj_4(C4) + self.upsample(P5)
        P3 = self.prj_3(C3) + self.upsample(P4)
        P3 = self.conv_3(P3)
        P4 = self.conv_4(P4)
        P5 = self.conv_5(P5)
        # Extra pyramid levels from C5.
        P6 = self.conv_out6(C5)
        P7 = self.conv_out7(self.relu(P6))
        pred_list = [P3, P4, P5, P6, P7]
        assert len(pred_list) == len(self.regions) - 1

        # Shared heads over every level, flattened to (b, H*W, ...)
        cls_out, reg_out = [], []
        for i, feat in enumerate(pred_list):
            # cls: [b, classes, H, W] -> [b, H*W, classes]
            c = self.conv_cls(feat).permute(0, 2, 3, 1).contiguous()
            cls_out.append(c.view(c.shape[0], -1, self.classes))
            # reg: per-level scale then exp keeps distances positive;
            # [b, 4, H, W] -> [b, H*W, 4]
            r = (self.conv_reg(feat) * self.scale_div[i]).exp()
            r = r.permute(0, 2, 3, 1).contiguous()
            reg_out.append(r.view(r.shape[0], -1, 4))
        cls_out = torch.cat(cls_out, dim=1)  # (b, shw, classes)
        reg_out = torch.cat(reg_out, dim=1)  # (b, shw, 4)

        if (label_class is None) or (label_box is None):
            return self._decode(cls_out, reg_out, loc)

        targets_cls, targets_reg = self._encode(label_class, label_box,
                                                loc)
        mask_cls = targets_cls > -1  # negatives + positives (b, shw)
        mask_reg = targets_cls > 0   # positives only        (b, shw)
        num_pos = torch.sum(mask_reg, dim=1).clamp_(min=1)  # (b)
        loss = []
        for b in range(targets_cls.shape[0]):
            # Convert predicted/target distances to absolute boxes before
            # masking, so the IoU loss sees box coordinates.
            reg_out[b] = distance2bbox(self.a_center_yx, reg_out[b])
            targets_reg[b] = distance2bbox(self.a_center_yx,
                                           targets_reg[b])
            loss_cls_b = sigmoid_focal_loss(
                cls_out[b][mask_cls[b]],      # (S+-, classes)
                targets_cls[b][mask_cls[b]],  # (S+-)
                2.0, 0.25).sum().view(1)
            loss_reg_b = iou_loss(
                reg_out[b][mask_reg[b]],      # (S+, 4)
                targets_reg[b][mask_reg[b]]   # (S+, 4)
                ).sum().view(1)
            loss.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
        return torch.cat(loss, dim=0)  # (b)