Esempio n. 1
0
    def test_offset_loss(self):
        """Check ``offset_loss`` against matching and partially flipped targets.

        Two 3-D offsets are used. When targets equal predictions, the norm
        term must be exactly 0 and the direction term must be -1. When the
        second target is flipped, that point is expected to contribute 2 to
        the norm term and +1 to the direction term. The expected values below
        are those contributions divided by 2, which is the third argument
        passed to ``offset_loss`` (presumably the point count used for
        averaging -- confirm against ``offset_loss``).
        """
        predicted = torch.tensor([[2, 0, 0], [0, 1, 0]]).float()
        num_points = 2

        # Case 1: every target matches its prediction.
        target = torch.tensor([[2, 0, 0], [0, 1, 0]]).float()
        result = offset_loss(predicted, target, num_points)
        self.assertEqual(result["offset_norm_loss"].item(), 0)
        self.assertAlmostEqual(result["offset_dir_loss"].item(), -1, places=5)

        # Case 2: the second target points in the opposite direction.
        target = torch.tensor([[2, 0, 0], [0, -1, 0]]).float()
        result = offset_loss(predicted, target, num_points)
        expected_norm = (0 + 2) / 2.0
        expected_dir = (-1 + 1) / 2.0
        self.assertAlmostEqual(result["offset_norm_loss"].item(), expected_norm, places=5)
        self.assertAlmostEqual(result["offset_dir_loss"].item(), expected_dir, places=5)
Esempio n. 2
0
    def _compute_loss(self):
        """Rebuild ``self.loss`` as a weighted sum of the model's loss terms.

        Side effects:
            * ``self.semantic_loss`` is set and ``self.loss`` is re-initialised
              from it.
            * ``self.input.instance_mask`` and ``self.input.vote_label`` are
              replaced by copies moved to ``self.device``.
            * Every entry returned by ``offset_loss`` is stored as an attribute
              named after its key and added to ``self.loss`` with the weight
              under the same key.
            * If cluster scores exist and the scorer is active, ``self.score_loss``
              is set and added to ``self.loss`` with the ``"score_loss"`` weight.
        """
        # Semantic term: NLL on the semantic logits, skipping IGNORE_LABEL targets.
        self.semantic_loss = torch.nn.functional.nll_loss(
            self.output.semantic_logits, self.labels.y, ignore_index=IGNORE_LABEL
        )
        weights = self.opt.loss_weights
        self.loss = weights.semantic * self.semantic_loss

        # Offset terms. They are computed only on the points selected by
        # instance_mask (assumed to be a boolean per-point mask -- it is used
        # for indexing and summed to give the third offset_loss argument).
        mask = self.input.instance_mask.to(self.device)
        self.input.instance_mask = mask
        votes = self.input.vote_label.to(self.device)
        self.input.vote_label = votes
        offset_terms = offset_loss(
            self.output.offset_logits[mask],
            votes[mask],
            torch.sum(mask),
        )
        for term_name, term_value in offset_terms.items():
            setattr(self, term_name, term_value)
            self.loss += weights[term_name] * term_value

        # Score term: skipped when there are no cluster scores or the scorer is off.
        if self.output.cluster_scores is None or not self._activate_scorer:
            return
        self.score_loss = instance_iou_loss(
            self.output.clusters,
            self.output.cluster_scores,
            self.input.instance_labels.to(self.device),
            self.input.batch.to(self.device),
            min_iou_threshold=self.opt.min_iou_threshold,
            max_iou_threshold=self.opt.max_iou_threshold,
        )
        self.loss += self.score_loss * weights["score_loss"]