def _get_cost_matrix(self, out, tgt):
    """Build the weighted matching cost between predictions and targets.

    Args:
        out: Mapping providing ``"pred_scores"`` and ``"pred_boxes"``.
            ``pred_scores`` is indexed as ``[:, tgt_ids]``, so it is
            treated as a 2-D (prediction, class) tensor — presumably for a
            single sample; confirm against the caller.
        tgt: Mapping providing ``"labels"`` and ``"boxes"``.

    Returns:
        The sum of the L1 box cost, the class cost and the GIoU cost, each
        scaled by ``self.cost_bbox``, ``self.cost_class`` and
        ``self.cost_giou`` respectively. Entry ``[i, j]`` pairs prediction
        ``i`` with target ``j``.
    """
    out_prob, out_bbox = out["pred_scores"], out["pred_boxes"]
    tgt_ids, tgt_bbox = tgt["labels"], tgt["boxes"]

    # Class cost: negated score that each prediction gives to each
    # target's label, so a higher score means a lower cost.
    cost_class = -out_prob[:, tgt_ids]

    # Box cost: pairwise L1 distance between the raw box parameters.
    cost_bbox = torch.cdist(out_bbox, tgt_bbox, p=1)

    # GIoU cost: negated pairwise generalized IoU, computed on the boxes
    # after conversion by box_cxcyczwhd_to_xyzxyz.
    cost_giou = -generalized_box_iou(
        box_cxcyczwhd_to_xyzxyz(out_bbox),
        box_cxcyczwhd_to_xyzxyz(tgt_bbox),
    )

    cost = (
        self.cost_bbox * cost_bbox
        + self.cost_class * cost_class
        + self.cost_giou * cost_giou
    )
    return cost
def loss_boxes(self, outputs, targets, indices, num_boxes):
    """Compute the box regression losses for the matched predictions.

    Args:
        outputs: Mapping that must contain ``"pred_boxes"``.
        targets: Sequence of per-sample mappings, each with a ``"boxes"``
            entry.
        indices: Per-sample ``(src, tgt)`` index pairs; the target half
            selects rows of each sample's ``"boxes"`` and the whole list is
            passed to ``self._get_src_permutation_idx`` to select the
            matching predictions.
        num_boxes: Normalisation constant for the summed losses.

    Returns:
        A dict with ``"loss_bbox"`` (summed L1 error divided by
        ``num_boxes * 4``) and, only when ``"loss_giou"`` is a key of
        ``self.weight_dict``, ``"loss_giou"`` (summed ``1 - GIoU`` of each
        matched pair divided by ``num_boxes``).
    """
    assert "pred_boxes" in outputs

    idx = self._get_src_permutation_idx(indices)
    src_boxes = outputs["pred_boxes"][idx]
    target_boxes = torch.cat(
        [t["boxes"][i] for t, (_, i) in zip(targets, indices)], dim=0
    )

    loss_bbox = F.l1_loss(src_boxes, target_boxes, reduction="none")

    losses = {}
    # NOTE(review): the divisor 4 is kept exactly as in the original. The
    # boxes go through box_cxcyczwhd_to_xyzxyz, which suggests six
    # coordinates per box — confirm that 4 is the intended normaliser.
    losses["loss_bbox"] = loss_bbox.sum() / (num_boxes * 4)

    if "loss_giou" in self.weight_dict:
        # generalized_box_iou is pairwise; the diagonal keeps only the
        # value for each prediction and its own matched target.
        loss_giou = 1 - torch.diag(
            box_ops.generalized_box_iou(
                box_ops.box_cxcyczwhd_to_xyzxyz(src_boxes),
                box_ops.box_cxcyczwhd_to_xyzxyz(target_boxes),
            )
        )
        losses["loss_giou"] = loss_giou.sum() / num_boxes

    return losses
def forward(self, outputs, targets):
    """Turn raw model outputs into per-sample detection results.

    Args:
        outputs: Mapping with ``"pred_logits"`` and ``"pred_boxes"``, and
            optionally ``"pred_masks"``.
        targets: Sequence of per-sample mappings providing ``"size"`` and,
            when ``self.rescale_to_orig_size`` is set, ``"orig_size"``.

    Returns:
        A list with one dict per sample holding ``"scores"``, ``"labels"``
        and ``"boxes"``, plus ``"masks"`` when ``outputs`` contains
        ``"pred_masks"``.
    """
    out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

    # Convert to corner format and scale to absolute coordinates. The scale
    # vector is ordered (w, h, d, w, h, d), i.e. the converted boxes are
    # read as [x0, y0, z0, x1, y1, z1].
    field = "orig_size" if self.rescale_to_orig_size else "size"
    out_bbox = box_ops.box_cxcyczwhd_to_xyzxyz(out_bbox)
    boxes = []
    for sample_boxes, target in zip(out_bbox, targets):
        img_d, img_h, img_w = target[field].tolist()
        scale = torch.tensor(
            [img_w, img_h, img_d, img_w, img_h, img_d],
            dtype=torch.float32,
            device=sample_boxes.device,
        )
        boxes.append(sample_boxes * scale)

    # The last class slot is dropped before taking the best class per
    # prediction (presumably the "no object" class — confirm).
    prob = F.softmax(out_logits, -1)
    scores, labels = prob[..., :-1].max(-1)

    results = [
        {"scores": score, "labels": label, "boxes": box}
        for score, label, box in zip(scores, labels, boxes)
    ]

    if "pred_masks" in outputs:
        # NOTE(review): this branch reads size[0] / size[1] as height and
        # width, while the box branch above unpacks three values (d, h, w)
        # from the same kind of field. Confirm the mask path is valid for
        # three-element sizes.
        max_h = max(tgt["size"][0] for tgt in targets)
        max_w = max(tgt["size"][1] for tgt in targets)

        outputs_masks = outputs["pred_masks"].squeeze(2)
        outputs_masks = F.interpolate(
            outputs_masks,
            size=(max_h, max_w),
            mode="bilinear",
            align_corners=False,
        ).sigmoid()
        # Binarise at self.threshold and move the result to the CPU.
        out_masks = (outputs_masks > self.threshold).byte().cpu().detach()

        for i, (cur_mask, target) in enumerate(zip(out_masks, targets)):
            img_h, img_w = target["size"][0], target["size"][1]
            # Crop each mask back to this sample's own size.
            results[i]["masks"] = cur_mask[:, :img_h, :img_w].unsqueeze(1)
            if self.rescale_to_orig_size:
                results[i]["masks"] = F.interpolate(
                    results[i]["masks"].float(),
                    size=tuple(target["orig_size"].tolist()),
                    mode="nearest",
                ).byte()

    return results