def regress_rotation(
     self,
     rotation_gt: Tensor,
     encoding: Tensor,
     loss: Tensor,
     log: Dict[str, Tensor],
     pred_gt: Dict[str, list],
 ) -> Tensor:
     """Predict rotation from ``encoding`` and add its weighted L1 loss.

     The output of ``self.rotation_head`` is compared with ``rotation_gt``
     using an L1 loss. That loss is divided by ``exp(self.log_sigma_rotate)``
     and ``self.log_sigma_rotate`` is added, so the learned log-sigma both
     down-weights the term and is penalised for growing.

     Args:
         rotation_gt: Target rotation; assumed to be shape-compatible with
             ``self.rotation_head(encoding)`` (not checked here).
         encoding: Input fed to ``self.rotation_head``.
         loss: Running total loss. It is NOT modified in place; the
             updated total is returned.
         log: Updated in place with detached ``"loss_rotation"`` and
             ``"sigma_rotation"`` entries.
         pred_gt: Updated in place with
             ``"rotation": [rotation_gt, rotation_pred]``.

     Returns:
         ``loss`` plus the weighted rotation loss.
     """
     rotation_pred = self.rotation_head(encoding)
     # nn.L1Loss takes (input, target); |a - b| is symmetric so the value
     # is unchanged, but this order matches PyTorch's convention.
     loss_rotation = L1Loss()(rotation_pred, rotation_gt)
     sigma_rotation = torch.exp(self.log_sigma_rotate)
     weighted = loss_rotation / sigma_rotation + self.log_sigma_rotate
     # Out-of-place add: ``loss += ...`` would mutate the caller's tensor
     # and raises if ``loss`` is a leaf tensor that requires grad.
     loss = loss + weighted
     log.update({
         "loss_rotation": loss_rotation.detach(),
         "sigma_rotation": sigma_rotation.detach(),
     })
     pred_gt.update({"rotation": [rotation_gt, rotation_pred]})
     return loss
 def regress_scaling(
     self,
     scale_gt: Tensor,
     encoding: Tensor,
     loss: Tensor,
     log: Dict[str, Tensor],
     pred_gt: Dict[str, list],
 ) -> Tensor:
     """Predict scale from ``encoding`` and add its weighted L1 loss.

     The output of ``self.scale_head`` is compared with ``scale_gt`` using
     an L1 loss. That loss is divided by ``exp(self.log_sigma_scale)`` and
     ``self.log_sigma_scale`` is added, so the learned log-sigma both
     down-weights the term and is penalised for growing.

     Args:
         scale_gt: Target scale; assumed to be shape-compatible with
             ``self.scale_head(encoding)`` (not checked here).
         encoding: Input fed to ``self.scale_head``.
         loss: Running total loss. It is NOT modified in place; the
             updated total is returned.
         log: Updated in place with detached ``"loss_scale"`` and
             ``"sigma_scale"`` entries.
         pred_gt: Updated in place with
             ``"scale": [scale_gt, scale_pred]``.

     Returns:
         ``loss`` plus the weighted scale loss.
     """
     scale_pred = self.scale_head(encoding)
     # nn.L1Loss takes (input, target); |a - b| is symmetric so the value
     # is unchanged, but this order matches PyTorch's convention.
     loss_scale = L1Loss()(scale_pred, scale_gt)
     sigma_scale = torch.exp(self.log_sigma_scale)
     weighted = loss_scale / sigma_scale + self.log_sigma_scale
     # Out-of-place add: ``loss += ...`` would mutate the caller's tensor
     # and raises if ``loss`` is a leaf tensor that requires grad.
     loss = loss + weighted
     log.update({
         "loss_scale": loss_scale.detach(),
         "sigma_scale": sigma_scale.detach(),
     })
     pred_gt.update({"scale": [scale_gt, scale_pred]})
     return loss
 def regress_jittering(
     self,
     jitter_gt: Tensor,
     encoding: Tensor,
     loss: Tensor,
     log: Dict[str, Tensor],
     pred_gt: Dict[str, list],
 ) -> Tensor:
     """Predict jitter from ``encoding`` and add its weighted L1 loss.

     The output of ``self.jitter_head`` is compared with ``jitter_gt`` using
     an L1 loss. That loss is divided by ``exp(self.log_sigma_jitter)`` and
     ``self.log_sigma_jitter`` is added, so the learned log-sigma both
     down-weights the term and is penalised for growing.

     Args:
         jitter_gt: Target jitter; assumed to be shape-compatible with
             ``self.jitter_head(encoding)`` (not checked here).
         encoding: Input fed to ``self.jitter_head``.
         loss: Running total loss. It is NOT modified in place; the
             updated total is returned.
         log: Updated in place with detached ``"loss_jitter"`` and
             ``"sigma_jitter"`` entries.
         pred_gt: Updated in place with
             ``"jitter": [jitter_gt, jitter_pred]``.

     Returns:
         ``loss`` plus the weighted jitter loss.
     """
     jitter_pred = self.jitter_head(encoding)
     # nn.L1Loss takes (input, target); |a - b| is symmetric so the value
     # is unchanged, but this order matches PyTorch's convention.
     loss_jitter = L1Loss()(jitter_pred, jitter_gt)
     sigma_jitter = torch.exp(self.log_sigma_jitter)
     weighted = loss_jitter / sigma_jitter + self.log_sigma_jitter
     # Out-of-place add: ``loss += ...`` would mutate the caller's tensor
     # and raises if ``loss`` is a leaf tensor that requires grad.
     loss = loss + weighted
     log.update({
         "loss_jitter": loss_jitter.detach(),
         "sigma_jitter": sigma_jitter.detach(),
     })
     pred_gt.update({"jitter": [jitter_gt, jitter_pred]})
     return loss