Пример #1
0
 def _get_cost_matrix(self, out, tgt):
     """Build the weighted pairwise matching cost of predictions vs. targets.

     Three terms are summed, each scaled by its ``self.cost_*`` weight:

     * box: pairwise L1 distance (``torch.cdist`` with ``p=1``) between
       ``out["pred_boxes"]`` and ``tgt["boxes"]``;
     * class: the negated columns of ``out["pred_scores"]`` selected by
       ``tgt["labels"]``;
     * GIoU: the negated ``generalized_box_iou`` of both box sets after
       converting them with ``box_cxcyczwhd_to_xyzxyz``.

     Args:
         out: mapping providing ``"pred_scores"`` and ``"pred_boxes"``.
         tgt: mapping providing ``"labels"`` and ``"boxes"``.

     Returns:
         The combined cost tensor.
     """
     scores = out["pred_scores"]
     pred_boxes = out["pred_boxes"]
     labels = tgt["labels"]
     gt_boxes = tgt["boxes"]

     # Scores and GIoU are negated so that they act as costs.
     class_term = -scores[:, labels]
     l1_term = torch.cdist(pred_boxes, gt_boxes, p=1)
     pred_corners = box_cxcyczwhd_to_xyzxyz(pred_boxes)
     gt_corners = box_cxcyczwhd_to_xyzxyz(gt_boxes)
     giou_term = -generalized_box_iou(pred_corners, gt_corners)

     weighted_l1 = self.cost_bbox * l1_term
     weighted_class = self.cost_class * class_term
     weighted_giou = self.cost_giou * giou_term
     return weighted_l1 + weighted_class + weighted_giou
Пример #2
0
    def loss_boxes(self, outputs, targets, indices, num_boxes):
        """Compute box regression losses for matched prediction/target pairs.

        Args:
            outputs: mapping that must contain ``"pred_boxes"``.
            targets: per-sample mappings, each providing ``"boxes"``.
            indices: per-sample index pairs; the second element of each pair
                selects that sample's target boxes, and the whole list is
                handed to ``self._get_src_permutation_idx`` to select the
                matching predictions.
            num_boxes: normaliser applied to the summed losses.

        Returns:
            A dict holding ``"loss_bbox"`` (summed element-wise L1 error
            divided by ``num_boxes * 4``) and, only when ``"loss_giou"`` is a
            key of ``self.weight_dict``, ``"loss_giou"`` (summed ``1 - GIoU``
            of the matched pairs divided by ``num_boxes``).
        """
        assert "pred_boxes" in outputs
        matched_idx = self._get_src_permutation_idx(indices)
        pred = outputs["pred_boxes"][matched_idx]
        gt_per_sample = [
            sample["boxes"][tgt_idx]
            for sample, (_, tgt_idx) in zip(targets, indices)
        ]
        gt = torch.cat(gt_per_sample, dim=0)

        l1_total = F.l1_loss(pred, gt, reduction="none").sum()
        # NOTE(review): the divisor hard-codes 4 although the conversion
        # helper's name (cxcyczwhd) suggests six coordinates per box --
        # confirm the constant is intended.
        losses = {"loss_bbox": l1_total / (num_boxes * 4)}

        if "loss_giou" not in self.weight_dict:
            return losses

        giou = box_ops.generalized_box_iou(
            box_ops.box_cxcyczwhd_to_xyzxyz(pred),
            box_ops.box_cxcyczwhd_to_xyzxyz(gt),
        )
        # torch.diag keeps entry (i, i): prediction i against its own target.
        giou_loss = 1 - torch.diag(giou)
        losses["loss_giou"] = giou_loss.sum() / num_boxes
        return losses
Пример #3
0
    def forward(self, outputs, targets):
        """Turn raw model outputs into one result dict per sample.

        Args:
            outputs: mapping providing ``"pred_logits"`` and ``"pred_boxes"``,
                and optionally ``"pred_masks"``.
            targets: per-sample mappings providing ``"size"`` (and
                ``"orig_size"`` when ``self.rescale_to_orig_size`` is set).

        Returns:
            A list of dicts with ``"scores"``, ``"labels"`` and ``"boxes"``;
            a ``"masks"`` entry is added when ``outputs`` has ``"pred_masks"``.
        """
        out_logits, out_bbox = outputs["pred_logits"], outputs["pred_boxes"]

        # Boxes are converted to corner form and scaled by the sample's size.
        # The scale vector repeats (w, h, d) twice, so the layout is presumably
        # [x0, y0, z0, x1, y1, z1] -- confirm against box_ops.
        size_key = "orig_size" if self.rescale_to_orig_size else "size"
        corner_boxes = box_ops.box_cxcyczwhd_to_xyzxyz(out_bbox)
        boxes = []
        for sample_boxes, target in zip(corner_boxes, targets):
            depth, height, width = target[size_key].tolist()
            scale = torch.tensor(
                [width, height, depth, width, height, depth],
                dtype=torch.float32,
                device=sample_boxes.device,
            )
            boxes.append(sample_boxes * scale)

        # The last class column is dropped before picking the best class
        # (presumably the background / no-object class -- confirm).
        scores, labels = F.softmax(out_logits, -1)[..., :-1].max(-1)
        results = [
            {"scores": score, "labels": lab, "boxes": box}
            for score, lab, box in zip(scores, labels, boxes)
        ]

        if "pred_masks" not in outputs:
            return results

        # NOTE(review): this branch reads size[0]/size[1] as height/width and
        # interpolates bilinearly, while the box loop above unpacks three
        # values (d, h, w) from the size field -- confirm it is valid for
        # the same targets.
        max_h = max(tgt["size"][0] for tgt in targets)
        max_w = max(tgt["size"][1] for tgt in targets)
        mask_probs = F.interpolate(
            outputs["pred_masks"].squeeze(2),
            size=(max_h, max_w),
            mode="bilinear",
            align_corners=False,
        ).sigmoid()
        binary_masks = (mask_probs > self.threshold).byte().cpu().detach()

        for i, (sample_mask, target) in enumerate(zip(binary_masks, targets)):
            valid_h, valid_w = target["size"][0], target["size"][1]
            # Crop away the area beyond this sample's own size.
            masks = sample_mask[:, :valid_h, :valid_w].unsqueeze(1)
            if self.rescale_to_orig_size:
                masks = F.interpolate(
                    masks.float(),
                    size=tuple(target["orig_size"].tolist()),
                    mode="nearest",
                ).byte()
            results[i]["masks"] = masks

        return results