def forward(self, x, loc, label_class=None, label_box=None):
    '''
    Forward pass: FPN features -> shared heads -> loss (training) or
    decoded detections (inference).

    Param:
    x:           FloatTensor(batch_num, 3, H, W)
    loc:         FloatTensor(batch_num, 4)
    label_class: LongTensor(batch_num, N_max) or None
    label_box:   FloatTensor(batch_num, N_max, 4) or None
    Return 1 (labels given):
    loss: FloatTensor(batch_num)
    Return 2 (inference):
    cls_i_preds: LongTensor(batch_num, topk)
    cls_p_preds: FloatTensor(batch_num, topk)
    reg_preds:   FloatTensor(batch_num, topk, 4)
    '''
    c3, c4, c5 = self.backbone(x)

    # Top-down FPN pathway with lateral 1x1 projections.
    p5 = self.prj_5(c5)
    p4 = self.prj_4(c4) + self.upsample(p5)
    p3 = self.prj_3(c3) + self.upsample(p4)
    p3, p4, p5 = self.conv_3(p3), self.conv_4(p4), self.conv_5(p5)
    p6 = self.conv_out6(c5)
    p7 = self.conv_out7(self.relu(p6))

    if self.balanced_fpn:
        # Libra-style balancing: bring P3/P5 to P4 resolution, average,
        # then redistribute the fused map back to the three levels.
        # max_pool2d args: kernel_size, stride, padding, dilation, ceil_mode, return_indices
        fused = (F.max_pool2d(p3, 3, 2, 1, 1, False, False)
                 + p4 + self.upsample(p5)) / 3.0
        # avg_pool2d args: kernel_size, stride, padding, ceil_mode, count_include_pad
        p5 = F.avg_pool2d(fused, 3, 2, 1, False, False)
        p3 = self.upsample(fused)
        p4 = fused

    features = [p3, p4, p5, p6, p7]
    assert len(features) == self.scales

    def _flatten(head, feat, channels):
        # [b, an*channels, H, W] -> [b, H*W*an, channels]
        out = head(feat).permute(0, 2, 3, 1).contiguous()
        return out.view(out.shape[0], -1, channels)

    cls_out = torch.cat(
        [_flatten(self.conv_cls, f, self.classes) for f in features], dim=1)
    reg_out = torch.cat(
        [_flatten(self.conv_reg, f, 4) for f in features], dim=1)
    # cls_out: (b, hwan, classes); reg_out: (b, hwan, 4)

    if label_class is None or label_box is None:
        return self._decode(cls_out, reg_out, loc)

    # (b, hwan), (b, hwan, 4)
    targets_cls, targets_reg = self._encode(label_class, label_box, loc)
    mask_cls = targets_cls > -1  # positives + negatives (ignore == -1)
    mask_reg = targets_cls > 0   # positives only
    num_pos = torch.sum(mask_reg, dim=1).clamp_(min=self.scales)  # (b)

    batch_losses = []
    for b in range(targets_cls.shape[0]):
        loss_cls_b = sigmoid_focal_loss(cls_out[b][mask_cls[b]],
                                        targets_cls[b][mask_cls[b]],
                                        2.0, 0.25).sum().view(1)
        loss_reg_b = smooth_l1_loss(reg_out[b][mask_reg[b]],
                                    targets_reg[b][mask_reg[b]],
                                    0.11).sum().view(1)
        batch_losses.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
    return torch.cat(batch_losses, dim=0)  # (b)
def forward(self, imgs, locs, label_class=None, label_box=None):
    '''
    FCOS-style forward pass.

    Param:
    imgs: F(b, 3, vsz, vsz)
    locs: F(b, 4)
    label_class: L(b, N_max) or None
    label_box:   F(b, N_max, 4) or None
    Return 1 (labels given):
    loss: F(b)
    Return 2 (inference):
    pred_cls_i: L(b, topk)
    pred_cls_p: F(b, topk)
    pred_reg:   F(b, topk, 4)
    '''
    # Backbone + top-down FPN.
    c3, c4, c5 = self.backbone(imgs)
    p5 = self.prj_5(c5)
    p4 = self.prj_4(c4) + self.upsample(p5)
    p3 = self.prj_3(c3) + self.upsample(p4)
    p3, p4, p5 = self.conv_3(p3), self.conv_4(p4), self.conv_5(p5)
    p6 = self.conv_out6(c5)
    p7 = self.conv_out7(self.relu(p6))
    features = [p3, p4, p5, p6, p7]
    assert len(features) == len(self.r)

    # Heads: classification logits and decoded boxes for every level.
    cls_levels, reg_levels = [], []
    for lvl, feat in enumerate(features):
        logits = self.conv_cls(feat).permute(0, 2, 3, 1).contiguous()
        dist = self.conv_reg(feat) * self.scale_param[lvl]
        dist = dist.permute(0, 2, 3, 1).contiguous()  # b, ph, pw, 4
        boxes = self.decode_box(dist, self.view_size, self.view_size,
                                self.phpw[lvl][0], self.phpw[lvl][1])
        cls_levels.append(logits.view(logits.shape[0], -1, self.classes))
        reg_levels.append(boxes.view(boxes.shape[0], -1, 4))
    pred_cls = torch.cat(cls_levels, dim=1)  # F(b, n, classes)
    pred_reg = torch.cat(reg_levels, dim=1)  # F(b, n, 4)

    if label_class is None or label_box is None:
        return self._decode(pred_cls, pred_reg, locs)

    # Cap the number of ground-truth objects per image at 200.
    if min(label_class.shape[1], 200) == 200:
        label_class = label_class[:, :200]
        label_box = label_box[:, :200, :]

    # Per-level target assignment, flattened and concatenated.
    cls_targets, reg_targets = [], []
    for lvl in range(len(self.r)):
        t_cls, t_reg = assign_box(label_class, label_box, locs,
                                  self.view_size, self.view_size,
                                  self.phpw[lvl][0], self.phpw[lvl][1],
                                  self.tlbr_max_minmax[lvl][0],
                                  self.tlbr_max_minmax[lvl][1],
                                  self.r[lvl])
        cls_targets.append(t_cls.view(t_cls.shape[0], -1))
        reg_targets.append(t_reg.view(t_reg.shape[0], -1, 4))
    target_cls = torch.cat(cls_targets, dim=1)  # L(b, n)
    target_reg = torch.cat(reg_targets, dim=1)  # F(b, n, 4)

    # Per-image loss normalized by the number of positives.
    m_negpos = target_cls > -1  # B(b, n) positives + negatives
    m_pos = target_cls > 0      # B(b, n) positives only
    num_pos = torch.sum(m_pos, dim=1).clamp_(min=1)  # L(b)
    per_image = []
    for b in range(locs.shape[0]):
        focal = sigmoid_focal_loss(pred_cls[b][m_negpos[b]],
                                   target_cls[b][m_negpos[b]],
                                   2.0, 0.25).sum().view(1)
        iou = iou_loss(pred_reg[b][m_pos[b]],
                       target_reg[b][m_pos[b]]).sum().view(1)
        per_image.append((focal + iou) / float(num_pos[b]))
    return torch.cat(per_image, dim=0)  # F(b)
def forward(self, x, loc, label_class=None, label_box=None):
    '''
    Forward pass through a two-stage BiFPN neck with learned fusion
    weights, then the shared heads; returns per-image loss when labels
    are given, otherwise decoded detections.

    Param:
    x: FloatTensor(batch_num, 3, H, W)
    loc: FloatTensor(batch_num, 4)
    label_class: LongTensor(batch_num, N_max) or None
    label_box: FloatTensor(batch_num, N_max, 4) or None
    Return 1:
    loss: FloatTensor(batch_num)
    Return 2:
    cls_i_preds: LongTensor(batch_num, topk)
    cls_p_preds: FloatTensor(batch_num, topk)
    reg_preds: FloatTensor(batch_num, topk, 4)
    '''
    # Fast normalized fusion (BiFPN): ReLU the raw weights, then divide
    # by their column sums (+eps) so each fusion's weights sum to ~1.
    # w_2_in holds the 2-input fusion weights, w_3_in the 3-input ones;
    # the second index below selects which fusion node the column feeds.
    w_2_in = self.w_relu(self.w_2_in)
    w_3_in = self.w_relu(self.w_3_in)
    w_2_in /= torch.sum(w_2_in, dim=0) + self.eps
    w_3_in /= torch.sum(w_3_in, dim=0) + self.eps
    C3, C4, C5 = self.backbone(x)
    # Lateral projections + BN: D* are the BiFPN input nodes.
    D5 = self.bi1_prj1_5(C5)
    D5 = self.bi1_prj1_5bn(D5)
    D4 = self.bi1_prj1_4(C4)
    D4 = self.bi1_prj1_4bn(D4)
    D3 = self.bi1_prj1_3(C3)
    D3 = self.bi1_prj1_3bn(D3)
    # First BiFPN cell: top-down (E4, F3) then bottom-up (F4, F5).
    E4 = self.bi1_prj2_4(
        (w_2_in[0][0] * D4 + w_2_in[1][0] * self.upsample(D5)))
    F3 = self.bi1_prj3_3(
        (w_2_in[0][1] * D3 + w_2_in[1][1] * self.upsample(E4)))
    F4 = self.bi1_prj3_4((w_3_in[0][0] * self.downsample(F3) +
                          w_3_in[1][0] * E4 + w_3_in[2][0] * D4))
    F5 = self.bi1_prj3_5(
        (w_2_in[0][2] * D5 + w_2_in[1][2] * self.downsample(F4)))
    # Second BiFPN cell over (F3, F4, F5): outputs H3, H4, H5.
    G4 = self.bi2_prj2_4(
        (w_2_in[0][3] * F4 + w_2_in[1][3] * self.upsample(F5)))
    H3 = self.bi2_prj3_3(
        (w_2_in[0][4] * F3 + w_2_in[1][4] * self.upsample(G4)))
    H4 = self.bi2_prj3_4((w_3_in[0][1] * self.downsample(H3) +
                          w_3_in[1][1] * G4 + w_3_in[2][1] * F4))
    H5 = self.bi2_prj3_5(
        (w_2_in[0][5] * F5 + w_2_in[1][5] * self.downsample(H4)))
    # Extra coarse levels come straight from C5, as in RetinaNet.
    P6 = self.conv_out6(C5)
    P7 = self.conv_out7(self.relu(P6))
    P3, P4, P5, P6, P7 = H3, H4, H5, P6, P7
    if self.balanced_fpn:
        # Libra-style balancing: pool P3 / upsample P5 to P4 resolution,
        # average the three, then redistribute the fused map.
        # max_pool2d positional args:
        # kernel_size, stride, padding, dilation, ceil_mode, return_indices
        P3 = F.max_pool2d(P3, 3, 2, 1, 1, False, False)
        P5 = self.upsample(P5)
        P4 = (P3 + P4 + P5) / 3.0
        # avg_pool2d positional args:
        # kernel_size, stride, padding, ceil_mode, count_include_pad
        P5 = F.avg_pool2d(P4, 3, 2, 1, False, False)
        P3 = self.upsample(P4)
    pred_list = [P3, P4, P5, P6, P7]
    assert len(pred_list) == self.scales
    cls_out = []
    reg_out = []
    for item in pred_list:
        cls_i = self.conv_cls(item)
        reg_i = self.conv_reg(item)
        # cls_i: [b, an*classes, H, W] -> [b, H*W*an, classes]
        cls_i = cls_i.permute(0, 2, 3, 1).contiguous()
        cls_i = cls_i.view(cls_i.shape[0], -1, self.classes)
        # reg_i: [b, an*4, H, W] -> [b, H*W*an, 4]
        reg_i = reg_i.permute(0, 2, 3, 1).contiguous()
        reg_i = reg_i.view(reg_i.shape[0], -1, 4)
        cls_out.append(cls_i)
        reg_out.append(reg_i)
    # cls_out: (b, hwan, classes); reg_out: (b, hwan, 4)
    cls_out = torch.cat(cls_out, dim=1)
    reg_out = torch.cat(reg_out, dim=1)
    if (label_class is not None) and (label_box is not None):
        targets_cls, targets_reg = self._encode(
            label_class, label_box, loc)  # (b, hwan), (b, hwan, 4)
        mask_cls = targets_cls > -1  # (b, hwan) positives + negatives
        mask_reg = targets_cls > 0   # (b, hwan) positives only
        # NOTE(review): floor on num_pos is self.scales here, unlike the
        # sibling forwards that clamp to 1 — presumably intentional;
        # confirm against training config.
        num_pos = torch.sum(mask_reg, dim=1).clamp_(min=self.scales)  # (b)
        loss = []
        for b in range(targets_cls.shape[0]):
            cls_out_b = cls_out[b][mask_cls[b]]  # (S+-, classes)
            reg_out_b = reg_out[b][mask_reg[b]]  # (S+, 4)
            targets_cls_b = targets_cls[b][mask_cls[b]]  # (S+-)
            targets_reg_b = targets_reg[b][mask_reg[b]]  # (S+, 4)
            loss_cls_b = sigmoid_focal_loss(cls_out_b, targets_cls_b,
                                            2.0, 0.25).sum().view(1)
            loss_reg_b = smooth_l1_loss(reg_out_b, targets_reg_b,
                                        0.11).sum().view(1)
            loss.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
        return torch.cat(loss, dim=0)  # (b)
    else:
        return self._decode(cls_out, reg_out, loc)
def forward(self, x, loc, label_class=None, label_box=None):
    '''
    Anchor-free forward pass: FPN features -> per-level heads with a
    learnable scale and exp() on the regression branch, then loss
    (training) or decoded detections (inference).

    Param:
    x: FloatTensor(batch_num, 3, H, W)
    loc: FloatTensor(batch_num, 4)
    label_class: LongTensor (batch_num, N_max) or None
    label_box: FloatTensor(batch_num, N_max, 4) or None
    Return 1:
    loss: FloatTensor(batch_num)
    Return 2:
    cls_i_preds: LongTensor (batch_num, topk)
    cls_p_preds: FloatTensor(batch_num, topk)
    reg_preds: FloatTensor(batch_num, topk, 4)
    '''
    C3, C4, C5 = self.backbone(x)
    # Top-down FPN pathway.
    P5 = self.prj_5(C5)
    P4 = self.prj_4(C4)
    P3 = self.prj_3(C3)
    P4 = P4 + self.upsample(P5)
    P3 = P3 + self.upsample(P4)
    P3 = self.conv_3(P3)
    P4 = self.conv_4(P4)
    P5 = self.conv_5(P5)
    P6 = self.conv_out6(C5)
    P7 = self.conv_out7(self.relu(P6))
    pred_list = [P3, P4, P5, P6, P7]
    assert len(pred_list) == len(self.regions) - 1
    cls_out = []
    reg_out = []
    for i, item in enumerate(pred_list):
        cls_i = self.conv_cls(item)
        # Per-level learnable scale; exp() keeps predicted distances positive.
        reg_i = (self.conv_reg(item) * self.scale_div[i]).exp()
        # cls_i: [b, classes, H, W] -> [b, H*W, classes]
        # reg_i: [b, 4, H, W] -> [b, H*W, 4]
        cls_i = cls_i.permute(0, 2, 3, 1).contiguous()
        reg_i = reg_i.permute(0, 2, 3, 1).contiguous()
        cls_i = cls_i.view(cls_i.shape[0], -1, self.classes)
        reg_i = reg_i.view(reg_i.shape[0], -1, 4)
        cls_out.append(cls_i)
        reg_out.append(reg_i)
    cls_out = torch.cat(cls_out, dim=1)  # (b, shw, classes)
    reg_out = torch.cat(reg_out, dim=1)  # (b, shw, 4)
    if (label_class is not None) and (label_box is not None):
        targets_cls, targets_reg = self._encode(label_class, label_box, loc)
        mask_cls = targets_cls > -1  # (b, shw) positives + negatives
        mask_reg = targets_cls > 0   # (b, shw) positives only
        num_pos = torch.sum(mask_reg, dim=1).clamp_(min=1)  # (b)
        loss = []
        for b in range(targets_cls.shape[0]):
            # FIX: decode distances to boxes out-of-place. The original wrote
            # the decoded values back via `reg_out[b] = ...` and
            # `targets_reg[b] = ...`; that is an in-place index_put on
            # non-leaf tensors participating in the autograd graph (reg_out
            # comes from the conv heads), which can raise a RuntimeError at
            # backward time and needlessly mutates the encoded targets.
            reg_boxes_b = distance2bbox(self.a_center_yx, reg_out[b])
            target_boxes_b = distance2bbox(self.a_center_yx, targets_reg[b])
            cls_out_b = cls_out[b][mask_cls[b]]        # (S+-, classes)
            reg_out_b = reg_boxes_b[mask_reg[b]]       # (S+, 4)
            targets_cls_b = targets_cls[b][mask_cls[b]]    # (S+-)
            targets_reg_b = target_boxes_b[mask_reg[b]]    # (S+, 4)
            loss_cls_b = sigmoid_focal_loss(cls_out_b, targets_cls_b,
                                            2.0, 0.25).sum().view(1)
            loss_reg_b = iou_loss(reg_out_b, targets_reg_b).sum().view(1)
            loss.append((loss_cls_b + loss_reg_b) / float(num_pos[b]))
        return torch.cat(loss, dim=0)  # (b)
    else:
        return self._decode(cls_out, reg_out, loc)