def __getitem__(self, index):
    """Load one sample and build its detection target.

    Crowd annotations (``iscrowd != 0``) are dropped. The remaining
    ``bbox`` entries are read in ``xywh`` form and converted to ``xyxy``.
    Their ``category_id`` values are remapped through
    ``self.category2id`` and stored as the ``labels`` field. Boxes are
    clipped to the image with empty boxes removed. Finally
    ``self.transform`` (if set) is applied jointly to image and target.

    Args:
        index: Sample index forwarded to the parent dataset's
            ``__getitem__``.

    Returns:
        Tuple ``(img, target, index)`` where ``target`` is the clipped
        ``BoxList`` (after the optional transform) and ``index`` is the
        input index, passed through unchanged.
    """
    img, annot = super().__getitem__(index)

    annot = [o for o in annot if o['iscrowd'] == 0]

    # reshape(-1, 4) keeps an (0, 4) shape when no annotations remain.
    boxes = torch.as_tensor([o['bbox'] for o in annot]).reshape(-1, 4)
    # NOTE(review): img.size is presumably PIL's (width, height) — confirm.
    target = BoxList(boxes, img.size, mode='xywh').convert('xyxy')

    classes = [self.category2id[o['category_id']] for o in annot]
    # Explicit int64: torch.tensor([]) would otherwise infer float32 for
    # images whose annotations were all crowd/empty.
    target.fields['labels'] = torch.tensor(classes, dtype=torch.int64)

    # clip() hands back a BoxList; keep it, otherwise the boxes removed by
    # remove_empty=True are silently retained.
    target = target.clip(remove_empty=True)

    if self.transform is not None:
        img, target = self.transform(img, target)

    return img, target, index
def forward_single_feature_map(self, location, cls_pred, box_pred, center_pred, image_sizes):
    """Decode one feature level's head outputs into per-image detections.

    Each (position, class) pair becomes a candidate when its sigmoid class
    probability exceeds ``self.threshold``. Candidates are ranked by class
    probability times sigmoid centerness. At most ``self.top_n`` are kept
    per image. Each survivor's box is decoded from its location and four
    regressed distances, and its final score is
    ``sqrt(sigmoid(cls) * sigmoid(center))``.

    Args:
        location: Per-position coordinates for this level, indexed by the
            flattened ``height * width`` position. Columns 0 and 1 are used
            as the point's x and y (presumably shape ``(H*W, 2)`` — confirm
            against caller).
        cls_pred: Class logits of shape ``(batch, channel, height, width)``.
        box_pred: Box regression, viewable as ``(batch, 4, height, width)``.
        center_pred: Centerness logits, viewable as
            ``(batch, 1, height, width)``.
        image_sizes: Per-image ``(height, width)`` pairs.

    Returns:
        List with one ``xyxy`` ``BoxList`` per image. Each carries
        ``labels`` (class index + 1) and ``scores`` fields. Boxes are
        clipped to the image without dropping empty ones, then filtered
        by ``remove_small_box`` with ``self.min_size``.
    """
    n_img, n_cls, feat_h, feat_w = cls_pred.shape

    def _flatten(tensor, depth):
        # (N, depth, H, W) -> (N, H*W, depth) so row k lines up with location[k].
        return (
            tensor.view(n_img, depth, feat_h, feat_w)
            .permute(0, 2, 3, 1)
            .reshape(n_img, -1, depth)
        )

    cls_prob = _flatten(cls_pred, n_cls).sigmoid()
    regress = _flatten(box_pred, 4)
    center_prob = _flatten(center_pred, 1).reshape(n_img, -1).sigmoid()

    # Candidacy is decided on the raw class probability; centerness only
    # reweights the scores used for top-k ranking below.
    is_candidate = cls_prob > self.threshold
    quota = is_candidate.view(n_img, -1).sum(1).clamp(max=self.top_n)
    ranked = cls_prob * center_prob[:, :, None]

    per_image = []
    for b in range(n_img):
        mask = is_candidate[b]
        score = ranked[b][mask]
        pos_idx, cls_idx = mask.nonzero().unbind(1)
        label = cls_idx + 1
        dist = regress[b][pos_idx]
        point = location[pos_idx]

        limit = quota[b]
        if mask.sum().item() > limit.item():
            score, pick = score.topk(limit, sorted=False)
            label, dist, point = label[pick], dist[pick], point[pick]

        # dist holds (left, top, right, bottom) distances from the point.
        px, py = point[:, 0], point[:, 1]
        boxes = torch.stack(
            (px - dist[:, 0], py - dist[:, 1], px + dist[:, 2], py + dist[:, 3]),
            dim=1,
        )

        img_h, img_w = image_sizes[b]
        result = BoxList(boxes, (int(img_w), int(img_h)), mode='xyxy')
        result.fields['labels'] = label
        result.fields['scores'] = score.sqrt()
        result = result.clip(remove_empty=False)
        per_image.append(remove_small_box(result, self.min_size))

    return per_image