示例#1
0
 def _compute_losses(self):
     """Compute the losses from ``votenet_module.get_loss`` and expose them.

     ``votenet_module.get_loss`` is called with ``self.input``,
     ``self.output`` and ``self.loss_params``; it returns a mapping of loss
     name -> value. For every entry whose value is a tensor, the tensor is
     stored on ``self`` under that name (``setattr``) and the name is added
     to ``self.loss_names`` if it is not already listed. Non-tensor entries
     are skipped.

     ``self.losses_has_been_added`` is set to True after the first call.
     This method no longer relies on the flag, but sets it for any other
     code that reads it.
     """
     losses = votenet_module.get_loss(self.input, self.output, self.loss_params)
     for loss_name, loss in losses.items():
         if not torch.is_tensor(loss):
             continue
         # Register by membership rather than "only on the first call": a
         # loss that first shows up on a later call is still tracked, and a
         # name is never listed twice.
         if loss_name not in self.loss_names:
             self.loss_names.append(loss_name)
         setattr(self, loss_name, loss)
     self.losses_has_been_added = True
示例#2
0
 def _compute_losses(self):
     """Compute the (optionally class-weighted) losses and expose them.

     If ``self._weight_classes`` is set, it is first moved to
     ``self.device``. It is then passed as ``weight_classes`` to
     ``votenet_module.get_loss`` together with ``self.input``,
     ``self.output`` and ``self.loss_params``. For every returned entry
     whose value is a tensor, the tensor is stored on ``self`` under that
     name and the name is added to ``self.loss_names`` if not already
     listed. Non-tensor entries are skipped.

     ``self.losses_has_been_added`` is set to True after the first call.
     This method no longer relies on the flag, but sets it for any other
     code that reads it.
     """
     if self._weight_classes is not None:
         # Keep the weights on self.device; stored back so later calls reuse
         # the already-moved tensor.
         self._weight_classes = self._weight_classes.to(self.device)
     losses = votenet_module.get_loss(self.input,
                                      self.output,
                                      self.loss_params,
                                      weight_classes=self._weight_classes)
     for loss_name, loss in losses.items():
         if not torch.is_tensor(loss):
             continue
         # Register by membership rather than "only on the first call": a
         # loss that first shows up on a later call is still tracked, and a
         # name is never listed twice.
         if loss_name not in self.loss_names:
             self.loss_names.append(loss_name)
         setattr(self, loss_name, loss)
     self.losses_has_been_added = True
示例#3
0
    def _compute_losses(self, *, semantic_loss_weight=10):
        """Compute the detection losses plus an optional semantic loss.

        Every tensor entry of the mapping returned by
        ``votenet_module.get_loss(self.input, self.output, self.loss_params)``
        is stored on ``self`` under its key, and the key is added to
        ``self.loss_names`` if not already listed. Non-tensor entries are
        skipped.

        When ``self.semantic_logits`` is not None, the following happens:

        * ``self.semantic_loss`` is set to
          ``nll_loss(self.semantic_logits, self.semantic_labels)``,
          ignoring targets equal to ``IGNORE_LABEL``.
        * The semantic loss, scaled by ``semantic_loss_weight``, is added to
          ``self.loss``.
        * ``"semantic_loss"`` is registered in ``self.loss_names``.

        Args:
            semantic_loss_weight: Multiplier applied to the semantic loss
                before adding it to ``self.loss``. Defaults to 10.
        """
        losses = votenet_module.get_loss(self.input, self.output,
                                         self.loss_params)
        for loss_name, loss in losses.items():
            if not torch.is_tensor(loss):
                continue
            # Register by membership rather than "only on the first call", so
            # a loss that first appears later is still tracked, without
            # duplicates.
            if loss_name not in self.loss_names:
                self.loss_names.append(loss_name)
            setattr(self, loss_name, loss)

        if self.semantic_logits is not None:
            # Membership check (not the first-call flag): if semantic_logits
            # was None on the first call, this still gets registered later.
            if "semantic_loss" not in self.loss_names:
                self.loss_names.append("semantic_loss")
            # nll_loss expects log-probabilities, so semantic_logits is
            # presumably log-softmax output — NOTE(review): confirm.
            self.semantic_loss = torch.nn.functional.nll_loss(
                self.semantic_logits,
                self.semantic_labels,
                ignore_index=IGNORE_LABEL)
            # Assumes self.loss was set above from a "loss" entry of the
            # mapping — TODO confirm. Out-of-place add, so the tensor object
            # returned by get_loss (aliased by self.loss via setattr) is not
            # mutated.
            self.loss = self.loss + semantic_loss_weight * self.semantic_loss
        self.losses_has_been_added = True